xllm源码分析(三)——推理流程
分析xllm模型的推理过程,一直追踪到算子层接口。
LLMEngine::step
// Excerpt of the engine-side step: issues step_async on every worker client
// and appends the returned futures to `futures`.
// NOTE(review): `futures`, `raw_forward_inputs` and the return statement are
// declared/written in parts of the function that this excerpt omits.
ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
// update dp related global parameters and then execute model
for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) {
// Integer division by dp_local_tp_size_ maps consecutive worker ranks onto
// the same dp_rank, so those workers receive the same raw_forward_inputs
// entry.
auto dp_rank = worker_rank / dp_local_tp_size_;
futures.emplace_back(
worker_clients_[worker_rank]->step_async(raw_forward_inputs[dp_rank]));
}
}
step_async
// Schedules one model step on the remote worker and returns a future for its
// result.
//
// The request is not sent from the calling thread: a task is queued on
// threadpool_, and that task hands the input and the promise to
// channel_->execute_model_async. The returned SemiFuture is tied to that
// promise.
//
// @param inputs  Forward input for this step; copied into the queued task.
// @return        SemiFuture obtained from the promise passed to the channel.
folly::SemiFuture<std::optional<RawForwardOutput>> RemoteWorker::step_async(
    const RawForwardInput& inputs) {
  folly::Promise<std::optional<RawForwardOutput>> promise;
  auto future = promise.getSemiFuture();
  // `inputs` is a const reference, so std::move on it would still copy; the
  // copy is spelled out here instead of being hidden behind a misleading move.
  threadpool_.schedule([this,
                        inputs = inputs,
                        promise = std::move(promise)]() mutable {
    // The closure owns its copy and does not use it afterwards, so move it
    // into the batch rather than copying again through an initializer_list.
    std::vector<RawForwardInput> batch;
    batch.emplace_back(std::move(inputs));
    channel_->execute_model_async(batch, promise);
  });
  return future;
}
channel_->execute_model_async
// Forwards the request and the promise unchanged to execute_model_with_brpc.
// NOTE(review): the bool returned by execute_model_with_brpc is discarded
// here.
void CommChannel::execute_model_async(
const std::vector<RawForwardInput>& inputs,
folly::Promise<std::optional<RawForwardOutput>>& promise) {
execute_model_with_brpc(inputs, promise);
}
// Serializes the first entry of `inputs` to proto::ForwardInput and issues the
// ExecuteModel RPC through stub_, passing a closure as the completion
// callback.
//
// `promise` is moved out of the caller's reference into the closure.
// This body returns true on every path.
bool CommChannel::execute_model_with_brpc(
const std::vector<RawForwardInput>& inputs,
folly::Promise<std::optional<RawForwardOutput>>& promise) {
// convert to proto::ForwardInput
proto::ForwardInput pb_forward_input;
// NOTE(review): only inputs[0] is serialized and this body has no emptiness
// check — confirm callers always pass at least one element.
forward_input_to_proto(inputs[0], &pb_forward_input);
// call ExecuteModel with callback
// The closure holds the controller (cntl), the response message (pb_output)
// and the promise. It is heap-allocated and not freed in this function;
// presumably ExecuteModelClosure deletes itself when run — TODO confirm.
auto done = new ExecuteModelClosure();
done->promise = std::move(promise);
stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done);
return true;
}
stub_->ExecuteModel是rpc调用。
远端的ExecuteModel
// RPC handler on the worker side: runs one model step for the request and
// writes the result into `pb_forward_output`.
//
// The handler only queues a task on threadpool_; the conversion, the step and
// the response serialization all happen inside that task.
void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
const proto::ForwardInput* pb_forward_input,
proto::ForwardOutput* pb_forward_output,
::google::protobuf::Closure* done) {
threadpool_->schedule(
[this, controller, pb_forward_input, pb_forward_output, done]() mutable {
// The guard takes over `done`, so the RPC is completed when this task's
// scope exits.
brpc::ClosureGuard done_guard(done);
// convert proto::ForwardInput to ForwardInput
Timer timer;
ForwardInput forward_input;
proto_to_forward_input(
pb_forward_input, forward_input, options_.num_decoding_tokens());
// model output
// These are all out-parameters of step() below and start out undefined.
torch::Tensor next_tokens;
torch::Tensor logprobs;
torch::Tensor top_tokens;
torch::Tensor top_logprobs;
torch::Tensor embeddings;
torch::Tensor expert_load_data;
int32_t prepared_layer_id = -1;
// beam search kernel output
torch::Tensor src_seq_idxes;
torch::Tensor out_tokens;
torch::Tensor out_logprobs;
step(forward_input,
next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs);
// convert to proto output
forward_output_to_proto(next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs,
pb_forward_output);
// The timer was started before the proto conversion, so the recorded
// latency covers conversion, step and response serialization.
COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
});
}
// Starts one step on worker_ via step_async. All tensor parameters and
// prepared_layer_id are passed by non-const reference.
// NOTE(review): this excerpt drops the body of the non-overlap branch, so how
// the future is consumed and how the output references get filled is not
// visible here.
void WorkerService::step(ForwardInput& fwd_input,
torch::Tensor& next_tokens,
torch::Tensor& logprobs,
torch::Tensor& top_tokens,
torch::Tensor& top_logprobs,
torch::Tensor& embeddings,
torch::Tensor& expert_load_data,
int32_t& prepared_layer_id,
torch::Tensor& src_seq_idxes,
torch::Tensor& out_tokens,
torch::Tensor& out_logprobs) {
// execute model
auto future = worker_->step_async(fwd_input);
if (!options_.enable_schedule_overlap()) {
}
}
worker_->step_async
// Thin forwarding wrapper: delegates to the implementation object impl_.
folly::SemiFuture<std::optional<ForwardOutput>> Worker::step_async(
const ForwardInput& inputs) {
return impl_->step_async(inputs);
}
// Prepares the input, then runs this->step() on threadpool_ and reports the
// result through the returned future.
//
// Two modes, selected by enable_schedule_overlap():
//  - overlap disabled: run step() and fulfil the promise with its result.
//  - overlap enabled: optionally rewrite the input from the previous step's
//    output, run step(), record the new output (or mark it invalid), then
//    fulfil the promise.
folly::SemiFuture<std::optional<ForwardOutput>> WorkerImpl::step_async(
const ForwardInput& input) {
// prepare_work_before_execute runs on the calling thread, before the task is
// queued. Presumably it produces the device-side copy of the input — TODO
// confirm.
ForwardInput input_on_device;
prepare_work_before_execute(input, input_on_device);
folly::Promise<std::optional<ForwardOutput>> promise;
auto future = promise.getSemiFuture();
threadpool_.schedule([this,
input = std::move(input_on_device),
promise = std::move(promise)]() mutable {
if (hierarchy_kv_cache_transfer_ != nullptr) {
hierarchy_kv_cache_transfer_->set_layer_synchronizer(input.input_params);
}
// run the model on the given input in working thread
if (!enable_schedule_overlap()) {
const auto output = this->step(input);
promise.setValue(output);
} else {
if (last_step_output_valid_ && !input.input_params.empty_kv_cache) {
// replace step i model input with true output of step i-1
input = update_input_by_last_step_output(input);
}
const auto output = this->step(input);
if (output.has_value()) {
if (is_driver() || FLAGS_enable_eplb) {
// Handshake on is_recorded_: wait until it is false, store the output,
// set it to true and wake one waiter. The code that resets it to false
// is not part of this excerpt.
std::unique_lock<std::mutex> lock(mtx_);
cv_.wait(lock, [this] { return !is_recorded_; });
update_last_step_output(output);
is_recorded_ = true;
cv_.notify_one();
} else {
update_last_step_output(output);
}
} else {
// step() produced no output: same handshake, but the previous output is
// marked invalid instead of being replaced.
if (is_driver() || FLAGS_enable_eplb) {
std::unique_lock<std::mutex> lock(mtx_);
cv_.wait(lock, [this] { return !is_recorded_; });
last_step_output_valid_ = false;
is_recorded_ = true;
cv_.notify_one();
} else {
last_step_output_valid_ = false;
}
}
promise.setValue(output);
}
});
return future;
}
LLMWorkerImpl::step
// Runs one forward pass and, when token indexes are selected, computes logits,
// samples, and optionally runs the beam-search kernel.
//
// Returns std::nullopt when the executor yields an undefined hidden-state
// tensor; otherwise returns the populated ForwardOutput.
std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
// NOTE(review): `timer` is not read anywhere in this excerpt.
Timer timer;
auto& sampling_params = input.sampling_params;
// temporarily use [0], will be adapted in next pr
// call model executor forward to get hidden states
auto hidden_states = model_executor_->forward(
input.token_ids, input.positions, kv_caches_, input.input_params);
if (!hidden_states.defined()) {
return std::nullopt;
}
// Logits are only computed when selected_token_idxes is defined.
torch::Tensor logits;
if (sampling_params.selected_token_idxes.defined()) {
logits =
model_->logits(hidden_states, sampling_params.selected_token_idxes);
}
ForwardOutput output;
// driver prepare model output
SampleOutput sample_output;
if (sampling_params.selected_token_idxes.defined()) {
sample_output = sampler_->forward(logits, sampling_params);
output.logits = logits;
// beam search kernel
// Runs only when beam search is requested and acc_logprob is a defined,
// non-empty tensor; otherwise beam_search_output stays default-constructed.
BeamSearchOutput beam_search_output;
if (sampling_params.use_beam_search && input.acc_logprob.defined() &&
input.acc_logprob.numel() > 0) {
beam_search_output = beam_searcher_->forward(input.acc_logprob,
sample_output.top_tokens,
sample_output.top_logprobs);
}
// set sample output to output
output.sample_output = sample_output;
// carry over the sampling params
output.do_sample = sampling_params.do_sample;
output.logprobs = sampling_params.logprobs;
output.max_top_logprobs = sampling_params.max_top_logprobs;
// set beam search output to output
output.beam_search_output = beam_search_output;
}
// With speculative decoding enabled the hidden states are also exported as
// embeddings: the full tensor for a non-decode batch on a non-draft worker,
// otherwise only the rows picked by selected_token_idxes (when defined).
if (options_.enable_speculative_decode()) {
if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) {
output.sample_output.embeddings = hidden_states;
} else if (sampling_params.selected_token_idxes.defined()) {
auto embeddings = hidden_states.index_select(
/*dim=*/0, sampling_params.selected_token_idxes);
output.sample_output.embeddings = embeddings;
}
}
return output;
}
model_executor_->forward
// Forwards the call unchanged to the executor implementation selected at
// construction time (impl_->run).
torch::Tensor Executor::forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& params) {
return impl_->run(tokens, positions, kv_caches, params);
}
// Builds the executor by asking the ExecutorImplFactory singleton for an
// implementation and storing it in impl_.
Executor::Executor(CausalLM* model,
const ModelArgs& args,
const torch::Device& device,
const runtime::Options& options) {
impl_ = ExecutorImplFactory::get_instance().create_executor_impl(
model, args, device, options);
}
// Creates the executor implementation registered for the active backend.
//
// The backend key is "base" unless FLAGS_enable_graph is set, in which case it
// is the string returned by Device::type_str(). Throws std::runtime_error when
// no creator is registered under that key.
std::unique_ptr<ExecutorImpl> ExecutorImplFactory::create_executor_impl(
CausalLM* model,
const ModelArgs& args,
const torch::Device& device,
const runtime::Options& options) {
std::string backend = "base";
if (FLAGS_enable_graph) {
backend = Device::type_str();
LOG(INFO) << "Creating Graph Executor for " << backend << " device";
}
// creators_ maps a backend key to a callable that builds the implementation.
auto it = creators_.find(backend);
if (it == creators_.end()) {
throw std::runtime_error("No valid graph backend found: " + backend);
}
return it->second(model, args, device, options);
}
有几种Executor。
REGISTER_EXECUTOR("npu", AclGraphExecutorImpl);
华为昇腾npu,使用AclGraph。
REGISTER_EXECUTOR("mlu", MluGraphExecutorImpl);
寒武纪mlu,使用MluGraph。
REGISTER_EXECUTOR("base", BaseExecutorImpl);
BaseExecutorImpl::run
// Increments the num_model_execution_total_eager counter and calls the
// model's forward directly with the given arguments.
torch::Tensor BaseExecutorImpl::run(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& params) {
COUNTER_INC(num_model_execution_total_eager);
return model_->forward(tokens, positions, kv_caches, params);
}
model_的创建
// Creates the model and the helper objects that depend on it.
//
// Fails a CHECK if a model already exists or if create_llm_model returns
// nullptr; otherwise returns true.
bool LLMWorkerImpl::init_model(ModelContext& context) {
CHECK(model_ == nullptr) << "Model is already initialized.";
// Try to create a causal LM model
model_ = create_llm_model(context);
// The model type was not found among the registered causal models
CHECK(model_ != nullptr) << "Failed to create model.";
// The executor receives a non-owning pointer to the model.
model_executor_ = std::make_unique<Executor>(
model_.get(), context.get_model_args(), device_, options_);
// The EPLB executor and the beam searcher are only created when their flags
// are set.
if (FLAGS_enable_eplb) {
eplb_executor_ = std::make_unique<EplbExecutor>(model_.get(), device_);
}
if (FLAGS_enable_beam_search_kernel) {
beam_searcher_ = std::make_unique<BeamSearcher>();
}
return true;
}
create_llm_model
std::unique_ptr<CausalLM> create_llm_model(const ModelContext& context) {
// get the factory function for the model type from model registry
auto factory = ModelRegistry::get_causallm_factory(
context.get_model_args().model_type());
if (factory) {
return factory(context);
}
LOG(ERROR) << "Unsupported model type: "
<< context.get_model_args().model_type();
return nullptr;
}
// Returns the causal-LM factory registered under `name`.
//
// When nothing is registered under `name`, a value-initialized (empty)
// CausalLMFactory is returned, which callers detect with a boolean test.
CausalLMFactory ModelRegistry::get_causallm_factory(const std::string& name) {
  ModelRegistry* instance = get_instance();
  // Use find() instead of operator[]: operator[] default-inserts an entry for
  // every unknown name, so a pure lookup would grow and mutate the registry.
  const auto it = instance->model_registry_.find(name);
  if (it == instance->model_registry_.end()) {
    return {};
  }
  return it->second.causal_lm_factory;
}
// Macro to register a model with the ModelRegistry
//
// Expands to a `const bool VarName_registered` whose initializer is an
// immediately-invoked lambda, so the registration runs when that variable is
// initialized. The registered factory:
//   1. constructs ModelClass from the ModelContext,
//   2. calls eval() on it,
//   3. wraps it in xllm::CausalLMImpl<ModelClass> together with the context's
//      tensor options and returns it as a unique_ptr.
// `#ModelType` stringizes the macro argument, so the registry key is the
// literal text passed as ModelType.
#define REGISTER_CAUSAL_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) \
const bool VarName##_registered = []() { \
ModelRegistry::register_causallm_factory( \
#ModelType, [](const ModelContext& context) { \
ModelClass model(context); \
model->eval(); \
return std::make_unique<xllm::CausalLMImpl<ModelClass>>( \
std::move(model), context.get_tensor_options()); \
}); \
return true; \
}()
// Convenience form: uses ModelType both as the variable-name prefix and as the
// registry key.
#define REGISTER_CAUSAL_MODEL(ModelType, ModelClass) \
REGISTER_CAUSAL_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)
models.h包含的模型
https://github.com/jd-opensource/xllm/blob/main/xllm/models/models.h
#if defined(USE_NPU)
#include "dit/pipeline_flux.h" // IWYU pragma: keep
#include "dit/pipeline_flux_control.h" // IWYU pragma: keep
#include "dit/pipeline_flux_fill.h" // IWYU pragma: keep
#include "llm/npu/deepseek_mtp.h" // IWYU pragma: keep
#include "llm/npu/deepseek_v2.h" // IWYU pragma: keep
#include "llm/npu/deepseek_v3.h" // IWYU pragma: keep
#include "llm/npu/deepseek_v32.h" // IWYU pragma: keep
#include "llm/npu/glm4.h" // IWYU pragma: keep
#include "llm/npu/glm4_moe.h" // IWYU pragma: keep
#include "llm/npu/glm4_moe_mtp.h" // IWYU pragma: keep
#include "llm/npu/kimi_k2.h" // IWYU pragma: keep
#include "llm/npu/llama.h" // IWYU pragma: keep
#include "llm/npu/llama3.h" // IWYU pragma: keep
#include "llm/npu/qwen2.h" // IWYU pragma: keep
#include "llm/npu/qwen3.h" // IWYU pragma: keep
#include "llm/npu/qwen3_embedding.h" // IWYU pragma: keep
#include "llm/npu/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/npu/glm4v.h" // IWYU pragma: keep
#include "vlm/npu/glm4v_moe.h" // IWYU pragma: keep
#include "vlm/npu/minicpmv.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_5_vl_mm_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_mm_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_moe.h" // IWYU pragma: keep
#elif defined(USE_MLU)
#include "llm/deepseek_mtp.h" // IWYU pragma: keep
#include "llm/deepseek_v2.h" // IWYU pragma: keep
#include "llm/deepseek_v3.h" // IWYU pragma: keep
#include "llm/deepseek_v32.h" // IWYU pragma: keep
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep
#elif defined(USE_ILU)
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#else
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep
#endif
Qwen3MoeForCausalLM
https://github.com/jd-opensource/xllm/blob/main/xllm/models/llm/qwen3_moe.h
// Declares the Qwen3MoeModel module holder from Qwen3MoeModelImpl.
TORCH_MODULE(Qwen3MoeModel);
// Causal-LM wrapper for Qwen3-MoE. It adds nothing of its own: the constructor
// only forwards the context to LlmForCausalLMImplBase<Qwen3MoeModel>.
class Qwen3MoeForCausalLMImpl : public LlmForCausalLMImplBase<Qwen3MoeModel> {
public:
Qwen3MoeForCausalLMImpl(const ModelContext& context)
: LlmForCausalLMImplBase<Qwen3MoeModel>(context) {}
};
TORCH_MODULE(Qwen3MoeForCausalLM);
// register the causal model
// The registry key is the stringized first argument, i.e. "qwen3_moe".
REGISTER_CAUSAL_MODEL(qwen3_moe, Qwen3MoeForCausalLM);
forward函数中构造attn_metadata。
// Build the attention metadata from the (modified) model input params.
auto attn_metadata = layer::AttentionMetadata::build(modified_input_params);
Qwen3MoeDecoderLayer
https://github.com/jd-opensource/xllm/blob/main/xllm/core/layers/qwen3_moe_decoder_layer.h
// One Qwen3-MoE decoder layer. Holds an attention module, two RMSNorm
// modules, and both a dense MLP and a fused-MoE MLP.
// NOTE(review): which of mlp_ / moe_mlp_ a given layer uses is decided in the
// implementation file, which is not shown here.
class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
public:
explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context,
int32_t layer_id);
// Loads this layer's weights from `state_dict`.
void load_state_dict(const StateDict& state_dict);
// Runs the layer on `x`. `x`, `residual` and `positions` are taken by
// non-const reference.
torch::Tensor forward(torch::Tensor& x,
std::optional<torch::Tensor>& residual,
torch::Tensor& positions,
const AttentionMetadata& attn_metadata,
KVCache& kv_cache,
const ModelInputParams& input_params);
private:
// Sub-modules start as null holders ({nullptr}).
Qwen2Attention attention_{nullptr};
DenseMLP mlp_{nullptr};
FusedMoE moe_mlp_{nullptr};
RMSNorm input_norm_{nullptr};
RMSNorm post_norm_{nullptr};
};
TORCH_MODULE(Qwen3MoeDecoderLayer);
Qwen2Attention
https://gitee.com/mirrors/xllm/blob/main/xllm/core/layers/common/qwen2_attention.h
// Attention block shared by Qwen2-style and Qwen3-style models; the
// is_qwen3_style_ flag gates the extra q/k normalization in forward().
class Qwen2AttentionImpl : public torch::nn::Module {
public:
// NOTE(review): the defaulted constructor leaves the scalar members below
// without an initializer.
Qwen2AttentionImpl() = default;
Qwen2AttentionImpl(const ModelContext& context);
// Computes attention for `hidden_states` at `positions`, using `kv_cache`.
torch::Tensor forward(const torch::Tensor& positions,
const torch::Tensor& hidden_states,
const AttentionMetadata& attn_metadata,
KVCache& kv_cache);
// Loads this block's weights from `state_dict`.
void load_state_dict(const StateDict& state_dict);
private:
int64_t num_heads_;
int64_t num_kv_heads_;
int64_t num_kv_head_replicas_;
int64_t head_dim_;
// Widths used by forward() to slice the fused qkv projection output along
// its last dimension: q takes q_size_, k and v take kv_size_ each.
int64_t q_size_;
int64_t kv_size_;
float scaling_;
bool is_qwen3_style_;
// Sub-modules start as null holders ({nullptr}).
QKVParallelLinear qkv_proj_{nullptr};
RowParallelLinear o_proj_{nullptr};
RMSNorm q_norm_{nullptr};
RMSNorm k_norm_{nullptr};
Attention attn_{nullptr};
MRotaryEmbedding rotary_emb_{nullptr};
};
TORCH_MODULE(Qwen2Attention);
Qwen2AttentionImpl::forward
// Attention forward pass: fused qkv projection, optional q/k norm, rotary
// embedding, attention over the KV cache, then the output projection.
torch::Tensor Qwen2AttentionImpl::forward(
const torch::Tensor& positions,
const torch::Tensor& hidden_states,
const AttentionMetadata& attn_metadata,
KVCache& kv_cache) {
// 1. qkv projection
// The fused output is split along the last dimension into
// [0, q_size_), [q_size_, q_size_ + kv_size_) and the following kv_size_.
auto qkv = qkv_proj_->forward(hidden_states);
auto q = qkv.slice(/*dim=*/-1, 0, q_size_);
auto k = qkv.slice(/*dim=*/-1, q_size_, q_size_ + kv_size_);
auto v = qkv.slice(/*dim=*/-1, q_size_ + kv_size_, q_size_ + 2 * kv_size_);
// Size of dimension 0 of q; presumably the number of tokens — TODO confirm.
const int64_t T = q.size(0);
if (is_qwen3_style_) {
// 2. q-norm
q = std::get<0>(q_norm_->forward(q));
// 3. k-norm
k = std::get<0>(k_norm_->forward(k));
}
// 4. rope
// The return value is not used, so q and k are presumably rotated in place —
// TODO confirm against MRotaryEmbedding.
rotary_emb_->forward(q, k, positions, attn_metadata);
q = q.view({T, q_size_});
k = k.view({T, kv_size_});
// 5. store k/v cache and do attention
auto out = std::get<0>(attn_->forward(attn_metadata, q, k, v, kv_cache));
// 6. output projection
return o_proj_->forward(out);
}
Attention
xllm-main/xllm/core/layers/common/attention.h
#if defined(USE_MLU)
#include "layers/mlu/attention.h"
#elif defined(USE_NPU)
#include "layers/npu_torch/attention.h"
#elif defined(USE_CUDA)
#include "layers/cuda/attention.h"
#elif defined(USE_ILU)
#include "layers/ilu/attention.h"
#endif
分析cuda平台的attention,xllm-main/xllm/core/layers/cuda/attention.h
// CUDA attention module. Stores the head configuration and exposes a forward
// that returns the attention output plus an optional second tensor.
class AttentionImpl : public torch::nn::Module {
public:
AttentionImpl() = default;
AttentionImpl(int num_heads,
int head_size,
float scale,
int num_kv_heads,
int sliding_window);
// Returns {output, output_lse}; the second element is optional.
std::tuple<torch::Tensor, std::optional<torch::Tensor>> forward(
const AttentionMetadata& attn_metadata,
torch::Tensor& query,
torch::Tensor& key,
torch::Tensor& value,
KVCache& kv_cache);
private:
int num_heads_;
int head_size_;
float scale_;
int num_kv_heads_;
// Passed as window_size_left when the attention plan is updated in forward().
int sliding_window_;
};
TORCH_MODULE(Attention);
xllm-main/xllm/core/layers/cuda/attention.cpp
// CUDA attention forward: refreshes the flashinfer plan info, then dispatches
// to batch_prefill or batch_decode depending on attn_metadata.is_prefill.
// NOTE(review): this is an excerpt — `output`, `output_lse`, `k_cache`,
// `v_cache` and `attention_params` are declared in code that is not shown, and
// the KV-cache write is not visible either.
std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
const AttentionMetadata& attn_metadata,
torch::Tensor& query,
torch::Tensor& key,
torch::Tensor& value,
KVCache& kv_cache) {
// maybe we need to update shared attn state before execute attention,
// currently we update flashinfer step_wise_attn_state_ at layer 0.
bool causal = attn_metadata.is_prefill || attn_metadata.is_chunked_prefill;
// Backend selection: determine_attention_backend() for the causal (prefill)
// case, the fixed string "fa2" otherwise. head_size_ is passed twice, so the
// qk and vo head dimensions are the same here.
flashinfer::update_plan_info(
attn_metadata.plan_info,
causal ? xllm::kernel::cuda::determine_attention_backend(
/*pos_encoding_mode=*/0,
/*use_fp16_qk_reduction=*/false,
/*use_custom_mask=*/false)
: "fa2",
attn_metadata,
query.scalar_type(),
key.scalar_type(),
output.scalar_type(),
head_size_,
head_size_,
num_heads_,
num_kv_heads_,
/*block_size*/ k_cache.size(1),
/*window_size_left*/ sliding_window_,
/*enable_cuda_graph*/ false,
/*causal*/ causal,
/*use_tensor_core*/ true);
// TODO: support chunked prefill
CHECK(!attn_metadata.is_chunked_prefill)
<< "chunked prefill is not supported";
if (attn_metadata.is_prefill) {
// Prefill works on the freshly projected key/value tensors.
attention_params.key = key;
attention_params.value = value;
xllm::kernel::batch_prefill(attention_params);
} else {
// Decode reads K/V from the paged cache, addressed through the
// paged_kv_* index tensors carried by attn_metadata.
attention_params.query = query;
attention_params.output = output;
attention_params.k_cache = k_cache;
attention_params.v_cache = v_cache;
attention_params.kv_seq_lens = attn_metadata.kv_seq_lens;
attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr;
attention_params.paged_kv_indices = attn_metadata.paged_kv_indices;
attention_params.paged_kv_last_page_len =
attn_metadata.paged_kv_last_page_len;
xllm::kernel::batch_decode(attention_params);
}
// Collapse the head dimensions so each row is num_heads_ * head_size_ wide.
output = output.view({-1, num_heads_ * head_size_});
return {output, output_lse};
}
cuda上使用flashinfer提供的paged attention核函数。
flashinfer::update_plan_info设置调用算子的uri。
// Assembles the URI string that names a batch-prefill kernel variant.
//
// The URI is the fixed prefix "batch_prefill_with_kv_cache_" followed by
// "<key>_<value>" fields joined with '_' for the four dtypes, both head
// dimensions, the positional-encoding mode and three boolean switches
// (spelled "True"/"False"). A "_sm90" suffix is appended for the "fa3"
// backend. dtype lookups use filename_safe_dtype_map.at().
std::string get_batch_prefill_uri(const std::string& backend,
                                  torch::ScalarType dtype_q,
                                  torch::ScalarType dtype_kv,
                                  torch::ScalarType dtype_o,
                                  torch::ScalarType dtype_idx,
                                  int64_t head_dim_qk,
                                  int64_t head_dim_vo,
                                  int64_t pos_encoding_mode,
                                  bool use_sliding_window,
                                  bool use_logits_soft_cap,
                                  bool use_fp16_qk_reduction) {
  // Booleans appear in the URI as the words "True" / "False".
  const auto flag_text = [](bool enabled) -> const char* {
    if (enabled) {
      return "True";
    }
    return "False";
  };

  std::ostringstream uri;
  uri << "batch_prefill_with_kv_cache_";
  uri << "dtype_q_" << filename_safe_dtype_map.at(dtype_q) << "_";
  uri << "dtype_kv_" << filename_safe_dtype_map.at(dtype_kv) << "_";
  uri << "dtype_o_" << filename_safe_dtype_map.at(dtype_o) << "_";
  uri << "dtype_idx_" << filename_safe_dtype_map.at(dtype_idx) << "_";
  uri << "head_dim_qk_" << head_dim_qk << "_";
  uri << "head_dim_vo_" << head_dim_vo << "_";
  uri << "posenc_" << pos_encoding_mode << "_";
  uri << "use_swa_" << flag_text(use_sliding_window) << "_";
  uri << "use_logits_cap_" << flag_text(use_logits_soft_cap) << "_";
  uri << "f16qk_" << flag_text(use_fp16_qk_reduction);

  const bool is_fa3 = (backend == "fa3");
  if (is_fa3) {
    uri << "_sm90";
  }
  return uri.str();
}
算子层api
batch_decode
xllm-main/xllm/core/kernels/ops_api.cpp
// Platform dispatch for decode attention: forwards the fields of `params` to
// the batch_decode of the backend selected at compile time (USE_MLU, USE_NPU,
// USE_CUDA or USE_ILU). Builds with none of these macros hit
// NOT_IMPLEMENTED().
//
// Each backend reads a different subset of `params`:
//  - MLU / ILU: block_table plus quantization scales, mask and window sizes
//    (ILU additionally passes is_causal).
//  - NPU: query, caches, scale, block_table, seq_lens and output only.
//  - CUDA: uri, plan_info, the workspace buffers and the paged_kv_* tensors.
// block_table.value() is called unconditionally on MLU/NPU/ILU, so it must be
// set on those platforms.
void batch_decode(AttentionParams& params) {
#if defined(USE_MLU)
  mlu::batch_decode(params.query,
                    params.k_cache,
                    params.output,
                    params.block_table.value(),
                    params.kv_seq_lens,
                    params.v_cache,
                    params.output_lse,
                    params.q_quant_scale,
                    params.k_cache_quant_scale,
                    params.v_cache_quant_scale,
                    params.out_quant_scale,
                    params.alibi_slope,
                    params.mask,
                    params.compute_dtype,
                    params.max_seq_len,
                    params.window_size_left,
                    params.window_size_right,
                    params.scale,
                    params.return_lse,
                    params.kv_cache_quant_bit_size);
#elif defined(USE_NPU)
  // An unset v_cache is replaced by a default-constructed tensor.
  npu::batch_decode(params.query,
                    params.k_cache,
                    params.v_cache.value_or(torch::Tensor()),
                    params.scale,
                    params.block_table.value(),
                    params.seq_lens,
                    params.output);
#elif defined(USE_CUDA)
  // An unset v_cache is replaced by a default-constructed tensor.
  cuda::batch_decode(params.uri,
                     params.plan_info,
                     params.float_workspace_buffer,
                     params.int_workspace_buffer,
                     params.page_locked_int_workspace_buffer,
                     params.query,
                     params.k_cache,
                     params.v_cache.value_or(torch::Tensor()),
                     params.paged_kv_indptr,
                     params.paged_kv_indices,
                     params.paged_kv_last_page_len,
                     params.window_size_left,
                     params.scale,
                     params.output,
                     params.output_lse,
                     params.enable_cuda_graph,
                     params.use_tensor_core,
                     params.kv_seq_lens);
#elif defined(USE_ILU)
  ilu::batch_decode(params.query,
                    params.k_cache,
                    params.output,
                    params.block_table.value(),
                    params.kv_seq_lens,
                    params.v_cache,
                    params.output_lse,
                    params.q_quant_scale,
                    params.k_cache_quant_scale,
                    params.v_cache_quant_scale,
                    params.out_quant_scale,
                    params.alibi_slope,
                    params.mask,
                    params.compute_dtype,
                    params.max_seq_len,
                    params.window_size_left,
                    params.window_size_right,
                    params.scale,
                    params.return_lse,
                    params.is_causal,
                    params.kv_cache_quant_bit_size);
#else
  NOT_IMPLEMENTED();
#endif
}
cuda::batch_decode
xllm-main/xllm/core/kernels/cuda/batch_decode.cpp
// CUDA decode attention. Looks up the kernel entry point for `uri` through
// FunctionFactory and calls it with the paged KV-cache description.
//
// Two paths, selected by `use_tensor_core`:
//  - true:  the FA2 paged *prefill* run function is used for decode, with a
//           qo_indptr tensor sized batch_size + 1.
//  - false: the dedicated decode run function is used.
// NOTE(review): `enable_cuda_graph` and `kv_seq_lens` are not referenced in
// this body.
void batch_decode(const std::string& uri,
torch::Tensor plan_info,
torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer,
torch::Tensor page_locked_int_workspace_buffer,
torch::Tensor query,
torch::Tensor k_cache,
torch::Tensor v_cache,
torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
int64_t window_left,
double sm_scale,
torch::Tensor output,
std::optional<torch::Tensor>& output_lse,
bool enable_cuda_graph,
bool use_tensor_core,
torch::Tensor kv_seq_lens) {
if (use_tensor_core) {
// Batch size is taken from the length of paged_kv_last_page_len.
const int64_t batch_size = paged_kv_last_page_len.size(0);
// qo_indptr comes from a CPU-side cached buffer and is copied to the GPU.
// NOTE(review): the buffer's contents are produced by get_cache_buffer,
// which is not shown here.
torch::Tensor qo_indptr_host =
get_cache_buffer(batch_size + 1, torch::kCPU);
torch::Tensor qo_indptr = qo_indptr_host.to(torch::kCUDA);
FunctionFactory::get_instance().fa2_prefill_paged_run_func(uri).call(
float_workspace_buffer,
int_workspace_buffer,
plan_info,
query,
k_cache,
v_cache,
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
output,
output_lse,
/*mask_mode_code=*/0, // NON_CAUSAL
/*kv_layout_code=*/0, // NHD layout
window_left,
support_pdl(),
/*maybe_custom_mask=*/std::optional<torch::Tensor>(),
/*maybe_mask_indptr=*/std::optional<torch::Tensor>(),
/*maybe_alibi_slopes=*/std::optional<torch::Tensor>(),
/*maybe_prefix_len_ptr=*/std::optional<torch::Tensor>(),
/*maybe_token_pos_in_items_ptr=*/std::optional<torch::Tensor>(),
/*maybe_max_item_len_ptr=*/std::optional<torch::Tensor>(),
/*logits_soft_cap=*/0.0,
sm_scale,
/*rope_rcp_scale=*/1.0,
/*rope_rcp_theta=*/1.0 / 10000.0,
/*token_pos_in_items_len=*/0);
} else {
// Non-tensor-core path: same buffers and paged-KV tensors, without
// qo_indptr or any of the mask arguments.
FunctionFactory::get_instance().decode_run_func(uri).call(
float_workspace_buffer,
int_workspace_buffer,
plan_info,
query,
k_cache,
v_cache,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
output,
output_lse,
/*kv_layout_code=*/0, // NHD layout
window_left,
support_pdl(),
/*maybe_alibi_slopes=*/std::optional<torch::Tensor>(),
/*logits_soft_cap=*/0.0,
sm_scale,
/*rope_rcp_scale=*/1.0,
/*rope_rcp_theta=*/1.0 / 10000.0);
}
}
fa2_prefill_paged_run_func
// Returns the typed handle for the "<uri>::paged_run" operator.
//
// On first use it loads the shared library for `uri` (path from
// path_to_uri_so_lib) and looks the operator up in the torch dispatcher
// (findSchemaOrThrow). The handle and the library are kept in function-local
// statics.
//
// NOTE(review): there is a single cached handle and a single once_flag, so
// after the first call the stored handle is returned regardless of the `uri`
// argument — confirm that callers only ever use one uri per process.
FA2_PREFILL_PAGED_RUN_FUNC_TYPE fa2_prefill_paged_run_func(
const std::string& uri) {
static std::optional<FA2_PREFILL_PAGED_RUN_FUNC_TYPE> f;
// The library object is held in a static as well, alongside the handle.
static std::unique_ptr<torch::DynamicLibrary> lib;
// NOTE(review): this fast-path read of `f` happens outside call_once —
// confirm first use cannot race with another thread.
if (f.has_value()) {
return f.value();
}
static std::once_flag flag;
std::call_once(flag, [&uri]() {
lib = std::make_unique<torch::DynamicLibrary>(
path_to_uri_so_lib(uri).c_str(), nullptr, true);
std::string run_schema_name = uri + "::paged_run";
// The typed<> signature below has to match the argument list used at the
// call site.
f = torch::Dispatcher::singleton()
.findSchemaOrThrow(run_schema_name.c_str(), "")
.typed<void(torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<torch::Tensor>,
int64_t,
int64_t,
int64_t,
bool,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
double,
double,
double,
double,
int64_t)>();
});
return f.value();
}
torch::DynamicLibrary加载动态库。动态库路径解析
// Resolves the shared-library path for a kernel URI:
//   <FLASHINFER_OPS_PATH>/<uri>/<uri>.so
// where the root directory is read from the FLASHINFER_OPS_PATH environment
// variable through util::get_string_env.
std::string path_to_uri_so_lib(const std::string& uri) {
  std::string so_path = util::get_string_env("FLASHINFER_OPS_PATH");
  so_path += "/";
  so_path += uri;  // per-kernel directory
  so_path += "/";
  so_path += uri;  // library base name
  so_path += ".so";
  return so_path;
}
更多推荐
所有评论(0)