LLMEngine::step

// Excerpt of the engine-side step: only the fan-out of one forward step to all
// worker clients is shown. The code that declares `futures` and
// `raw_forward_inputs`, waits on the futures and produces the returned
// ForwardOutput is elided in this listing.
// NOTE(review): as listed the function has no return statement — confirm
// against the full source.
ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
  // update dp related global parameters and then execute model
  for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) {
    // Consecutive groups of dp_local_tp_size_ workers share one dp rank, and
    // every worker in a group is sent the same raw input.
    auto dp_rank = worker_rank / dp_local_tp_size_;
    futures.emplace_back(
        worker_clients_[worker_rank]->step_async(raw_forward_inputs[dp_rank]));
  }
}

 step_async

// Schedules one forward step on the remote worker and returns a future for
// its output.
//
// The input is copied into the task closure on purpose: the task runs later on
// threadpool_, when the caller's `inputs` reference may no longer be valid.
// Ownership of `promise` is handed to the channel together with the request;
// the channel is expected to fulfil it when the RPC completes.
folly::SemiFuture<std::optional<RawForwardOutput>> RemoteWorker::step_async(
    const RawForwardInput& inputs) {
  folly::Promise<std::optional<RawForwardOutput>> promise;
  auto future = promise.getSemiFuture();
  threadpool_.schedule([this,
                        inputs = inputs,
                        promise = std::move(promise)]() mutable {
    // Move the closure's private copy into the one-element batch instead of
    // copying it again through a braced initializer list.
    std::vector<RawForwardInput> batch;
    batch.reserve(1);
    batch.push_back(std::move(inputs));
    channel_->execute_model_async(batch, promise);
  });

  return future;
}

 channel_->execute_model_async

// Entry point used by RemoteWorker: hands the batch to the brpc transport.
// The boolean status returned by execute_model_with_brpc is deliberately
// dropped here; this function returns void, so the caller observes the
// outcome only through `promise`.
void CommChannel::execute_model_async(
    const std::vector<RawForwardInput>& inputs,
    folly::Promise<std::optional<RawForwardOutput>>& promise) {
  static_cast<void>(execute_model_with_brpc(inputs, promise));
}

bool CommChannel::execute_model_with_brpc(
    const std::vector<RawForwardInput>& inputs,
    folly::Promise<std::optional<RawForwardOutput>>& promise) {
  // convert to proto::ForwardInput
  proto::ForwardInput pb_forward_input;
  forward_input_to_proto(inputs[0], &pb_forward_input);

  // call ExecuteModel with callback
  auto done = new ExecuteModelClosure();
  done->promise = std::move(promise);
  stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done);
  return true;
}

 stub_->ExecuteModel是一次异步的brpc调用:传入的done闭包(ExecuteModelClosure)会在调用完成后由brpc执行。

远端的ExecuteModel

// brpc handler for ExecuteModel.
//
// The handler itself returns immediately; the work runs as a task on
// threadpool_. The ClosureGuard is the first local of that task, so it is
// destroyed last and runs `done` (which brpc treats as completing the RPC)
// only after the request has been read and the response has been filled.
void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
                                 const proto::ForwardInput* pb_forward_input,
                                 proto::ForwardOutput* pb_forward_output,
                                 ::google::protobuf::Closure* done) {
  auto handle_request = [this,
                         controller,
                         pb_forward_input,
                         pb_forward_output,
                         done]() mutable {
    // Runs `done` when this task returns.
    brpc::ClosureGuard rpc_guard(done);

    Timer latency_timer;

    // Deserialize the wire request into the worker's input type.
    ForwardInput request;
    proto_to_forward_input(
        pb_forward_input, request, options_.num_decoding_tokens());

    // Out-parameters of step(): model outputs.
    torch::Tensor sampled_tokens;
    torch::Tensor sampled_logprobs;
    torch::Tensor top_k_tokens;
    torch::Tensor top_k_logprobs;
    torch::Tensor hidden_embeddings;
    torch::Tensor expert_load;
    int32_t prepared_layer = -1;
    // Out-parameters of step(): beam-search kernel outputs.
    torch::Tensor beam_src_seq_idxes;
    torch::Tensor beam_tokens;
    torch::Tensor beam_logprobs;

    step(request,
         sampled_tokens,
         sampled_logprobs,
         top_k_tokens,
         top_k_logprobs,
         hidden_embeddings,
         expert_load,
         prepared_layer,
         beam_src_seq_idxes,
         beam_tokens,
         beam_logprobs);

    // Serialize the results into the RPC response.
    forward_output_to_proto(sampled_tokens,
                            sampled_logprobs,
                            top_k_tokens,
                            top_k_logprobs,
                            hidden_embeddings,
                            expert_load,
                            prepared_layer,
                            beam_src_seq_idxes,
                            beam_tokens,
                            beam_logprobs,
                            pb_forward_output);
    COUNTER_ADD(worker_service_latency_seconds,
                latency_timer.elapsed_seconds());
  };
  threadpool_->schedule(std::move(handle_request));
}

// Excerpt: executes one step on the local worker; results are meant to be
// returned through the reference out-parameters. Most of the body is elided
// in this listing: as shown, `future` is never waited on and no out-parameter
// is assigned.
void WorkerService::step(ForwardInput& fwd_input,
                         torch::Tensor& next_tokens,
                         torch::Tensor& logprobs,
                         torch::Tensor& top_tokens,
                         torch::Tensor& top_logprobs,
                         torch::Tensor& embeddings,
                         torch::Tensor& expert_load_data,
                         int32_t& prepared_layer_id,
                         torch::Tensor& src_seq_idxes,
                         torch::Tensor& out_tokens,
                         torch::Tensor& out_logprobs) {
  // execute model
  auto future = worker_->step_async(fwd_input);

  if (!options_.enable_schedule_overlap()) {
    // (elided) NOTE(review): presumably waits on `future` and copies the
    // ForwardOutput into the out-parameters here — confirm in the full source.
  }
}

 worker_->step_async

// Public facade: forwards the asynchronous step to the implementation object
// held in impl_ and returns its future unchanged.
folly::SemiFuture<std::optional<ForwardOutput>> Worker::step_async(
    const ForwardInput& inputs) {
  auto& implementation = *impl_;
  return implementation.step_async(inputs);
}

// Schedules one model step on the worker's thread pool and returns a future
// for its output.
//
// The input is prepared into `input_on_device` on the calling thread
// (prepare_work_before_execute) and then moved into the task.
//
// With schedule overlap enabled, the task:
//   * replaces the input with the real output of the previous step when that
//     output is valid and the batch is not marked empty_kv_cache, and
//   * records this step's output (or marks it invalid) for the next step.
// On the driver, or when EPLB is enabled, recording is handshaked through
// mtx_/cv_/is_recorded_: the task waits until is_recorded_ is false before
// writing, then sets it and notifies. NOTE(review): the side that consumes
// the record and resets is_recorded_ is not shown here.
folly::SemiFuture<std::optional<ForwardOutput>> WorkerImpl::step_async(
    const ForwardInput& input) {
  ForwardInput input_on_device;

  prepare_work_before_execute(input, input_on_device);

  folly::Promise<std::optional<ForwardOutput>> promise;
  auto future = promise.getSemiFuture();
  threadpool_.schedule([this,
                        input = std::move(input_on_device),
                        promise = std::move(promise)]() mutable {
    if (hierarchy_kv_cache_transfer_ != nullptr) {
      hierarchy_kv_cache_transfer_->set_layer_synchronizer(input.input_params);
    }

    // run the model on the given input in working thread
    if (!enable_schedule_overlap()) {
      promise.setValue(this->step(input));
      return;
    }

    if (last_step_output_valid_ && !input.input_params.empty_kv_cache) {
      // replace step i model input with true output of step i-1
      input = update_input_by_last_step_output(input);
    }

    auto output = this->step(input);

    // Publish this step's result for the next step: either the output itself
    // or the fact that no valid output exists.
    auto record_step_output = [&]() {
      if (output.has_value()) {
        update_last_step_output(output);
      } else {
        last_step_output_valid_ = false;
      }
    };

    if (is_driver() || FLAGS_enable_eplb) {
      std::unique_lock<std::mutex> lock(mtx_);
      cv_.wait(lock, [this] { return !is_recorded_; });
      record_step_output();
      is_recorded_ = true;
      cv_.notify_one();
    } else {
      record_step_output();
    }

    // `output` is not used after this point; hand it over without copying.
    promise.setValue(std::move(output));
  });
  return future;
}

 LLMWorkerImpl::step

// Runs one forward pass of the LLM on `input`. When tokens are selected for
// sampling, it also computes logits, samples, and optionally runs the
// beam-search kernel.
//
// Returns std::nullopt when the executor returns an undefined hidden-states
// tensor. Otherwise returns a ForwardOutput whose logits/sampling/beam fields
// are populated only when sampling_params.selected_token_idxes is defined,
// and whose sample_output.embeddings is set when speculative decoding is on.
std::optional<ForwardOutput> LLMWorkerImpl::step(const ForwardInput& input) {
  const auto& sampling_params = input.sampling_params;

  // call model executor forward to get hidden states
  auto hidden_states = model_executor_->forward(
      input.token_ids, input.positions, kv_caches_, input.input_params);
  if (!hidden_states.defined()) {
    return std::nullopt;
  }

  ForwardOutput output;

  // Logits, sampling and beam search are only needed when this step has
  // tokens selected for sampling.
  if (sampling_params.selected_token_idxes.defined()) {
    auto logits =
        model_->logits(hidden_states, sampling_params.selected_token_idxes);
    SampleOutput sample_output = sampler_->forward(logits, sampling_params);
    output.logits = logits;

    // beam search kernel; beam_searcher_ is only created when
    // FLAGS_enable_beam_search_kernel is set at model init, so never
    // dereference it when it is absent.
    BeamSearchOutput beam_search_output;
    if (beam_searcher_ != nullptr && sampling_params.use_beam_search &&
        input.acc_logprob.defined() && input.acc_logprob.numel() > 0) {
      beam_search_output = beam_searcher_->forward(input.acc_logprob,
                                                   sample_output.top_tokens,
                                                   sample_output.top_logprobs);
    }

    // set sample output to output (the local is not used afterwards)
    output.sample_output = std::move(sample_output);
    // carry over the sampling params
    output.do_sample = sampling_params.do_sample;
    output.logprobs = sampling_params.logprobs;
    output.max_top_logprobs = sampling_params.max_top_logprobs;
    // set beam search output to output
    output.beam_search_output = std::move(beam_search_output);
  }

  if (options_.enable_speculative_decode()) {
    if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) {
      // Non-decode batch on the target model: keep all hidden states.
      output.sample_output.embeddings = hidden_states;
    } else if (sampling_params.selected_token_idxes.defined()) {
      // Otherwise keep only the rows of the selected tokens.
      output.sample_output.embeddings = hidden_states.index_select(
          /*dim=*/0, sampling_params.selected_token_idxes);
    }
  }

  return output;
}

 model_executor_->forward

// Runs one forward pass through the executor implementation chosen at
// construction time and returns the tensor produced by its run().
torch::Tensor Executor::forward(const torch::Tensor& tokens,
                                const torch::Tensor& positions,
                                std::vector<KVCache>& kv_caches,
                                const ModelInputParams& params) {
  ExecutorImpl& implementation = *impl_;
  return implementation.run(tokens, positions, kv_caches, params);
}

// Builds the executor by asking the singleton ExecutorImplFactory for the
// implementation that matches the current flags and device.
Executor::Executor(CausalLM* model,
                   const ModelArgs& args,
                   const torch::Device& device,
                   const runtime::Options& options)
    : impl_(ExecutorImplFactory::get_instance().create_executor_impl(
          model, args, device, options)) {}

// Selects and constructs the executor implementation.
// The registry key is "base" unless FLAGS_enable_graph is set, in which case
// the device type string (e.g. the key of a graph executor) is used.
// Throws std::runtime_error when no creator is registered under that key.
std::unique_ptr<ExecutorImpl> ExecutorImplFactory::create_executor_impl(
    CausalLM* model,
    const ModelArgs& args,
    const torch::Device& device,
    const runtime::Options& options) {
  const std::string backend = FLAGS_enable_graph
                                  ? std::string(Device::type_str())
                                  : std::string("base");
  if (FLAGS_enable_graph) {
    LOG(INFO) << "Creating Graph Executor for " << backend << " device";
  }

  if (auto creator = creators_.find(backend); creator != creators_.end()) {
    return creator->second(model, args, device, options);
  }
  throw std::runtime_error("No valid graph backend found: " + backend);
}

 有几种Executor(未开启FLAGS_enable_graph时使用"base",开启后按设备类型字符串选择图执行器):
REGISTER_EXECUTOR("npu", AclGraphExecutorImpl);
华为昇腾NPU,使用AclGraph。
REGISTER_EXECUTOR("mlu", MluGraphExecutorImpl);
寒武纪MLU,使用MluGraph。
REGISTER_EXECUTOR("base", BaseExecutorImpl);
 BaseExecutorImpl::run

// Eager (non-graph) execution path: bumps the eager-execution counter and
// calls the model's forward directly, returning its result.
torch::Tensor BaseExecutorImpl::run(const torch::Tensor& tokens,
                                    const torch::Tensor& positions,
                                    std::vector<KVCache>& kv_caches,
                                    const ModelInputParams& params) {
  COUNTER_INC(num_model_execution_total_eager);
  auto hidden_states = model_->forward(tokens, positions, kv_caches, params);
  return hidden_states;
}

model_的创建

// Creates the causal LM for `context` and wraps it in an Executor. It also
// builds the optional helpers enabled by flags: the EPLB executor and the
// beam-search kernel.
// Aborts via CHECK if called twice or if the model cannot be created.
// Returns true.
bool LLMWorkerImpl::init_model(ModelContext& context) {
  CHECK(model_ == nullptr) << "Model is already initialized.";

  model_ = create_llm_model(context);
  // create_llm_model yields nullptr for model types without a registered
  // factory.
  CHECK(model_ != nullptr) << "Failed to create model.";

  auto* const raw_model = model_.get();
  model_executor_ = std::make_unique<Executor>(
      raw_model, context.get_model_args(), device_, options_);

  if (FLAGS_enable_eplb) {
    eplb_executor_ = std::make_unique<EplbExecutor>(raw_model, device_);
  }
  if (FLAGS_enable_beam_search_kernel) {
    beam_searcher_ = std::make_unique<BeamSearcher>();
  }
  return true;
}

 create_llm_model

// Builds the causal LM using the factory registered for the model type named
// in the context's model args. Logs an error and returns nullptr when no
// usable factory is registered for that type.
std::unique_ptr<CausalLM> create_llm_model(const ModelContext& context) {
  auto factory = ModelRegistry::get_causallm_factory(
      context.get_model_args().model_type());
  if (!factory) {
    LOG(ERROR) << "Unsupported model type: "
               << context.get_model_args().model_type();
    return nullptr;
  }
  return factory(context);
}

// Returns the causal-LM factory registered under `name`, or an empty
// (value-initialized) factory when nothing is registered under that name.
//
// Uses find() rather than operator[] so that looking up an unknown name does
// not insert an empty entry into the registry as a side effect.
CausalLMFactory ModelRegistry::get_causallm_factory(const std::string& name) {
  ModelRegistry* instance = get_instance();

  const auto it = instance->model_registry_.find(name);
  if (it == instance->model_registry_.end()) {
    return CausalLMFactory{};
  }
  return it->second.causal_lm_factory;
}
// Macro to register a model with the ModelRegistry
//
// REGISTER_CAUSAL_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) defines a
// namespace-scope `const bool VarName_registered`. Its initializer registers,
// under the string "ModelType", a factory that:
//   1. constructs a ModelClass from the ModelContext,
//   2. calls model->eval() on it, and
//   3. wraps it in xllm::CausalLMImpl<ModelClass> together with the context's
//      tensor options.
// NOTE(review): a namespace-scope `const` variable has internal linkage, so
// when this macro is expanded in a header every translation unit including it
// runs the registration again — presumably harmless if registration simply
// overwrites; confirm in register_causallm_factory.
#define REGISTER_CAUSAL_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) \
  const bool VarName##_registered = []() {                                 \
    ModelRegistry::register_causallm_factory(                              \
        #ModelType, [](const ModelContext& context) {                      \
          ModelClass model(context);                                       \
          model->eval();                                                   \
          return std::make_unique<xllm::CausalLMImpl<ModelClass>>(         \
              std::move(model), context.get_tensor_options());             \
        });                                                                \
    return true;                                                           \
  }()

// Convenience form that reuses the model type token as the variable prefix,
// e.g. REGISTER_CAUSAL_MODEL(qwen3_moe, Qwen3MoeForCausalLM) registers the
// factory under "qwen3_moe".
#define REGISTER_CAUSAL_MODEL(ModelType, ModelClass) \
  REGISTER_CAUSAL_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)

models.h包含的模型

https://github.com/jd-opensource/xllm/blob/main/xllm/models/models.h

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#if defined(USE_NPU)
#include "dit/pipeline_flux.h"                // IWYU pragma: keep
#include "dit/pipeline_flux_control.h"        // IWYU pragma: keep
#include "dit/pipeline_flux_fill.h"           // IWYU pragma: keep
#include "llm/npu/deepseek_mtp.h"             // IWYU pragma: keep
#include "llm/npu/deepseek_v2.h"              // IWYU pragma: keep
#include "llm/npu/deepseek_v3.h"              // IWYU pragma: keep
#include "llm/npu/deepseek_v32.h"             // IWYU pragma: keep
#include "llm/npu/glm4.h"                     // IWYU pragma: keep
#include "llm/npu/glm4_moe.h"                 // IWYU pragma: keep
#include "llm/npu/glm4_moe_mtp.h"             // IWYU pragma: keep
#include "llm/npu/kimi_k2.h"                  // IWYU pragma: keep
#include "llm/npu/llama.h"                    // IWYU pragma: keep
#include "llm/npu/llama3.h"                   // IWYU pragma: keep
#include "llm/npu/qwen2.h"                    // IWYU pragma: keep
#include "llm/npu/qwen3.h"                    // IWYU pragma: keep
#include "llm/npu/qwen3_embedding.h"          // IWYU pragma: keep
#include "llm/npu/qwen3_moe.h"                // IWYU pragma: keep
#include "vlm/npu/glm4v.h"                    // IWYU pragma: keep
#include "vlm/npu/glm4v_moe.h"                // IWYU pragma: keep
#include "vlm/npu/minicpmv.h"                 // IWYU pragma: keep
#include "vlm/npu/qwen2_5_vl.h"               // IWYU pragma: keep
#include "vlm/npu/qwen2_5_vl_mm_embedding.h"  // IWYU pragma: keep
#include "vlm/npu/qwen2_vl.h"                 // IWYU pragma: keep
#include "vlm/npu/qwen2_vl_embedding.h"       // IWYU pragma: keep
#include "vlm/npu/qwen3_vl.h"                 // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_mm_embedding.h"    // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_moe.h"             // IWYU pragma: keep
#elif defined(USE_MLU)
#include "llm/deepseek_mtp.h"        // IWYU pragma: keep
#include "llm/deepseek_v2.h"         // IWYU pragma: keep
#include "llm/deepseek_v3.h"         // IWYU pragma: keep
#include "llm/deepseek_v32.h"        // IWYU pragma: keep
#include "llm/qwen2.h"               // IWYU pragma: keep
#include "llm/qwen3.h"               // IWYU pragma: keep
#include "llm/qwen3_moe.h"           // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h"          // IWYU pragma: keep
#include "vlm/qwen2_vl.h"            // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h"  // IWYU pragma: keep
#include "vlm/qwen3_vl.h"            // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h"        // IWYU pragma: keep
#elif defined(USE_ILU)
#include "llm/qwen2.h"      // IWYU pragma: keep
#include "llm/qwen3.h"      // IWYU pragma: keep
#include "llm/qwen3_moe.h"  // IWYU pragma: keep
#else
#include "llm/qwen2.h"               // IWYU pragma: keep
#include "llm/qwen3.h"               // IWYU pragma: keep
#include "llm/qwen3_moe.h"           // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h"          // IWYU pragma: keep
#include "vlm/qwen2_vl.h"            // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h"  // IWYU pragma: keep
#include "vlm/qwen3_vl.h"            // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h"        // IWYU pragma: keep
#endif

Qwen3MoeForCausalLM

https://github.com/jd-opensource/xllm/blob/main/xllm/models/llm/qwen3_moe.h

TORCH_MODULE(Qwen3MoeModel);

// Causal-LM wrapper for Qwen3-MoE: all behavior comes from
// LlmForCausalLMImplBase instantiated with Qwen3MoeModel.
class Qwen3MoeForCausalLMImpl : public LlmForCausalLMImplBase<Qwen3MoeModel> {
 public:
  // explicit: a ModelContext should never implicitly convert into a model.
  explicit Qwen3MoeForCausalLMImpl(const ModelContext& context)
      : LlmForCausalLMImplBase<Qwen3MoeModel>(context) {}
};
TORCH_MODULE(Qwen3MoeForCausalLM);

// register the causal model under the model type name "qwen3_moe"
REGISTER_CAUSAL_MODEL(qwen3_moe, Qwen3MoeForCausalLM);

模型的forward函数中构造attn_metadata:
auto attn_metadata = layer::AttentionMetadata::build(modified_input_params);

Qwen3MoeDecoderLayer

https://github.com/jd-opensource/xllm/blob/main/xllm/core/layers/qwen3_moe_decoder_layer.h

// One decoder layer of Qwen3-MoE.
//
// Holds an attention block (reusing Qwen2Attention), a dense MLP, a fused MoE
// MLP, and two RMSNorms.
// NOTE(review): presumably each layer uses either mlp_ or moe_mlp_ depending
// on layer_id / model config (dense vs. sparse layers) — confirm in the .cpp.
class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
 public:
  // Builds the submodules for layer `layer_id` from the model context.
  explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context,
                                    int32_t layer_id);

  // Presumably loads this layer's weights from `state_dict`.
  void load_state_dict(const StateDict& state_dict);

  // Runs the layer on hidden states `x`.
  // `residual` is passed as a mutable std::optional — presumably an in/out
  // residual stream that is empty before the first layer; confirm.
  torch::Tensor forward(torch::Tensor& x,
                        std::optional<torch::Tensor>& residual,
                        torch::Tensor& positions,
                        const AttentionMetadata& attn_metadata,
                        KVCache& kv_cache,
                        const ModelInputParams& input_params);

 private:
  Qwen2Attention attention_{nullptr};
  DenseMLP mlp_{nullptr};
  FusedMoE moe_mlp_{nullptr};
  RMSNorm input_norm_{nullptr};  // presumably the pre-attention norm
  RMSNorm post_norm_{nullptr};   // presumably the post-attention norm
};
TORCH_MODULE(Qwen3MoeDecoderLayer);

Qwen2Attention

https://gitee.com/mirrors/xllm/blob/main/xllm/core/layers/common/qwen2_attention.h

class Qwen2AttentionImpl : public torch::nn::Module {
 public:
  Qwen2AttentionImpl() = default;
  Qwen2AttentionImpl(const ModelContext& context);

  torch::Tensor forward(const torch::Tensor& positions,
                        const torch::Tensor& hidden_states,
                        const AttentionMetadata& attn_metadata,
                        KVCache& kv_cache);

  void load_state_dict(const StateDict& state_dict);

 private:
  int64_t num_heads_;
  int64_t num_kv_heads_;
  int64_t num_kv_head_replicas_;
  int64_t head_dim_;
  int64_t q_size_;
  int64_t kv_size_;
  float scaling_;
  bool is_qwen3_style_;

  QKVParallelLinear qkv_proj_{nullptr};
  RowParallelLinear o_proj_{nullptr};
  RMSNorm q_norm_{nullptr};
  RMSNorm k_norm_{nullptr};
  Attention attn_{nullptr};
  MRotaryEmbedding rotary_emb_{nullptr};
};
TORCH_MODULE(Qwen2Attention);

 Qwen2AttentionImpl::forward

// Attention forward for one batch of tokens. The stages are:
//   1. fused QKV projection, split along the last dimension into q | k | v;
//   2. optional Qwen3-style RMSNorm on q and k;
//   3. rotary embedding;
//   4. kv-cache update and attention;
//   5. output projection.
torch::Tensor Qwen2AttentionImpl::forward(
    const torch::Tensor& positions,
    const torch::Tensor& hidden_states,
    const AttentionMetadata& attn_metadata,
    KVCache& kv_cache) {
  // Fused projection; the last dimension is laid out as [ q | k | v ] with
  // widths q_size_, kv_size_ and kv_size_.
  auto fused_qkv = qkv_proj_->forward(hidden_states);
  const int64_t k_begin = q_size_;
  const int64_t v_begin = k_begin + kv_size_;
  const int64_t v_end = v_begin + kv_size_;

  auto query = fused_qkv.slice(/*dim=*/-1, 0, k_begin);
  auto key = fused_qkv.slice(/*dim=*/-1, k_begin, v_begin);
  auto value = fused_qkv.slice(/*dim=*/-1, v_begin, v_end);
  const int64_t num_tokens = query.size(0);

  // Qwen3-style models normalize q and k before the rotary embedding.
  if (is_qwen3_style_) {
    query = std::get<0>(q_norm_->forward(query));
    key = std::get<0>(k_norm_->forward(key));
  }

  // Rotary embedding on q/k, then view both as 2-D [num_tokens, width].
  rotary_emb_->forward(query, key, positions, attn_metadata);
  query = query.view({num_tokens, q_size_});
  key = key.view({num_tokens, kv_size_});

  // Store k/v into the cache and compute attention; only the first tuple
  // element (the attention output) is used.
  auto attn_out =
      std::get<0>(attn_->forward(attn_metadata, query, key, value, kv_cache));

  return o_proj_->forward(attn_out);
}

Attention

xllm-main/xllm/core/layers/common/attention.h

#if defined(USE_MLU)
#include "layers/mlu/attention.h"
#elif defined(USE_NPU)
#include "layers/npu_torch/attention.h"
#elif defined(USE_CUDA)
#include "layers/cuda/attention.h"
#elif defined(USE_ILU)
#include "layers/ilu/attention.h"
#endif

 分析cuda平台的attention,xllm-main/xllm/core/layers/cuda/attention.h

// CUDA attention layer: plans and dispatches FlashInfer batch prefill/decode
// kernels over the kv cache.
class AttentionImpl : public torch::nn::Module {
 public:
  // Default construction leaves the configuration at the value-initialized
  // (zero) in-class defaults below instead of indeterminate values.
  AttentionImpl() = default;

  AttentionImpl(int num_heads,
                int head_size,
                float scale,
                int num_kv_heads,
                int sliding_window);

  // Returns the attention output (flattened to [-1, num_heads * head_size])
  // and the optional `output_lse` tensor.
  std::tuple<torch::Tensor, std::optional<torch::Tensor>> forward(
      const AttentionMetadata& attn_metadata,
      torch::Tensor& query,
      torch::Tensor& key,
      torch::Tensor& value,
      KVCache& kv_cache);

 private:
  int num_heads_{};
  int head_size_{};
  float scale_{};
  int num_kv_heads_{};
  int sliding_window_{};  // passed to the kernels as window_size_left
};
TORCH_MODULE(Attention);

xllm-main/xllm/core/layers/cuda/attention.cpp

// CUDA attention forward (excerpt). Updates the FlashInfer plan for this
// batch, then runs either the batch-prefill or the paged batch-decode kernel,
// and returns the output flattened to [-1, num_heads_ * head_size_] together
// with `output_lse`.
// NOTE(review): `output`, `output_lse`, `k_cache`, `v_cache` and
// `attention_params` are declared in code elided from this listing
// (presumably derived from kv_cache / query) — confirm in attention.cpp.
std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
    const AttentionMetadata& attn_metadata,
    torch::Tensor& query,
    torch::Tensor& key,
    torch::Tensor& value,
    KVCache& kv_cache) {

  // maybe we need to update shared attn state before execute attention,
  // currently we update flashinfer step_wise_attn_state_ at layer 0.
  // Causal (prefill-like) batches ask determine_attention_backend for the
  // backend; decode batches always plan with "fa2".
  bool causal = attn_metadata.is_prefill || attn_metadata.is_chunked_prefill;
  flashinfer::update_plan_info(
      attn_metadata.plan_info,
      causal ? xllm::kernel::cuda::determine_attention_backend(
                   /*pos_encoding_mode=*/0,
                   /*use_fp16_qk_reduction=*/false,
                   /*use_custom_mask=*/false)
             : "fa2",
      attn_metadata,
      query.scalar_type(),
      key.scalar_type(),
      output.scalar_type(),
      head_size_,
      head_size_,
      num_heads_,
      num_kv_heads_,
      // NOTE(review): assumes the cache is laid out with block size at dim 1.
      /*block_size*/ k_cache.size(1),
      /*window_size_left*/ sliding_window_,
      /*enable_cuda_graph*/ false,
      /*causal*/ causal,
      /*use_tensor_core*/ true);

  // TODO: support chunked prefill
  // NOTE(review): `causal` above already accounts for chunked prefill, but
  // chunked prefill is rejected here, after the plan has been updated.
  CHECK(!attn_metadata.is_chunked_prefill)
      << "chunked prefill is not supported";
  if (attn_metadata.is_prefill) {
    // Only key/value are shown here; the remaining prefill parameters are
    // presumably filled in elided code.
    attention_params.key = key;
    attention_params.value = value;
    xllm::kernel::batch_prefill(attention_params);
  } else {
    // Decode: query attends over the paged kv cache described by the
    // paged_kv_* index tensors from the metadata.
    attention_params.query = query;
    attention_params.output = output;
    attention_params.k_cache = k_cache;
    attention_params.v_cache = v_cache;

    attention_params.kv_seq_lens = attn_metadata.kv_seq_lens;
    attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr;
    attention_params.paged_kv_indices = attn_metadata.paged_kv_indices;
    attention_params.paged_kv_last_page_len =
        attn_metadata.paged_kv_last_page_len;

    xllm::kernel::batch_decode(attention_params);
  }

  // Merge the per-head dimension back into one hidden dimension.
  output = output.view({-1, num_heads_ * head_size_});
  return {output, output_lse};
}

 cuda上使用flashinfer提供的paged attention核函数。
 flashinfer::update_plan_info根据后端、dtype、head_dim等参数确定要调用算子的uri(uri的拼接规则见下面的get_batch_prefill_uri)。

// Builds the name ("uri") of the JIT-compiled FlashInfer batch-prefill module
// for the given configuration.
// Booleans are spelled "True"/"False", and the "fa3" backend appends "_sm90".
// The uri is also used to locate the .so (path_to_uri_so_lib) and the
// "<uri>::paged_run" operator schema.
// Throws std::out_of_range (from .at()) for a dtype missing from
// filename_safe_dtype_map.
std::string get_batch_prefill_uri(const std::string& backend,
                                  torch::ScalarType dtype_q,
                                  torch::ScalarType dtype_kv,
                                  torch::ScalarType dtype_o,
                                  torch::ScalarType dtype_idx,
                                  int64_t head_dim_qk,
                                  int64_t head_dim_vo,
                                  int64_t pos_encoding_mode,
                                  bool use_sliding_window,
                                  bool use_logits_soft_cap,
                                  bool use_fp16_qk_reduction) {
  const auto py_bool = [](bool value) -> const char* {
    return value ? "True" : "False";
  };

  std::string uri = "batch_prefill_with_kv_cache_";
  uri.append("dtype_q_").append(filename_safe_dtype_map.at(dtype_q)).append("_");
  uri.append("dtype_kv_").append(filename_safe_dtype_map.at(dtype_kv)).append("_");
  uri.append("dtype_o_").append(filename_safe_dtype_map.at(dtype_o)).append("_");
  uri.append("dtype_idx_").append(filename_safe_dtype_map.at(dtype_idx)).append("_");
  uri.append("head_dim_qk_").append(std::to_string(head_dim_qk)).append("_");
  uri.append("head_dim_vo_").append(std::to_string(head_dim_vo)).append("_");
  uri.append("posenc_").append(std::to_string(pos_encoding_mode)).append("_");
  uri.append("use_swa_").append(py_bool(use_sliding_window)).append("_");
  uri.append("use_logits_cap_").append(py_bool(use_logits_soft_cap)).append("_");
  uri.append("f16qk_").append(py_bool(use_fp16_qk_reduction));

  if (backend == "fa3") {
    uri.append("_sm90");
  }
  return uri;
}

算子层api

batch_decode
xllm-main/xllm/core/kernels/ops_api.cpp

// Backend-neutral entry point for batch decode attention.
//
// Forwards the relevant fields of `params` to the kernel of the backend
// selected at build time (USE_MLU / USE_NPU / USE_CUDA / USE_ILU); any other
// build reaches NOT_IMPLEMENTED().
// - MLU, NPU and ILU read params.block_table.value(); if block_table is a
//   std::optional, an unset value throws std::bad_optional_access.
// - CUDA needs the FlashInfer uri/plan_info that update_plan_info produced.
void batch_decode(AttentionParams& params) {
#if defined(USE_MLU)
  mlu::batch_decode(params.query,
                    params.k_cache,
                    params.output,
                    params.block_table.value(),
                    params.kv_seq_lens,
                    params.v_cache,
                    params.output_lse,
                    params.q_quant_scale,
                    params.k_cache_quant_scale,
                    params.v_cache_quant_scale,
                    params.out_quant_scale,
                    params.alibi_slope,
                    params.mask,
                    params.compute_dtype,
                    params.max_seq_len,
                    params.window_size_left,
                    params.window_size_right,
                    params.scale,
                    params.return_lse,
                    params.kv_cache_quant_bit_size);
#elif defined(USE_NPU)
  npu::batch_decode(params.query,
                    params.k_cache,
                    params.v_cache.value_or(torch::Tensor()),
                    params.scale,
                    params.block_table.value(),
                    params.seq_lens,
                    params.output);
#elif defined(USE_CUDA)
  cuda::batch_decode(params.uri,
                     params.plan_info,
                     params.float_workspace_buffer,
                     params.int_workspace_buffer,
                     params.page_locked_int_workspace_buffer,
                     params.query,
                     params.k_cache,
                     params.v_cache.value_or(torch::Tensor()),
                     params.paged_kv_indptr,
                     params.paged_kv_indices,
                     params.paged_kv_last_page_len,
                     params.window_size_left,
                     params.scale,
                     params.output,
                     params.output_lse,
                     params.enable_cuda_graph,
                     params.use_tensor_core,
                     params.kv_seq_lens);
#elif defined(USE_ILU)
  ilu::batch_decode(params.query,
                    params.k_cache,
                    params.output,
                    params.block_table.value(),
                    params.kv_seq_lens,
                    params.v_cache,
                    params.output_lse,
                    params.q_quant_scale,
                    params.k_cache_quant_scale,
                    params.v_cache_quant_scale,
                    params.out_quant_scale,
                    params.alibi_slope,
                    params.mask,
                    params.compute_dtype,
                    params.max_seq_len,
                    params.window_size_left,
                    params.window_size_right,
                    params.scale,
                    params.return_lse,
                    params.is_causal,
                    params.kv_cache_quant_bit_size);
#else
  NOT_IMPLEMENTED();
#endif
}

cuda::batch_decode
xllm-main/xllm/core/kernels/cuda/batch_decode.cpp

// CUDA batch decode over a paged kv cache using FlashInfer JIT kernels
// resolved by `uri`. Two paths:
//   * use_tensor_core: runs FlashInfer's fa2 paged *prefill* kernel with a
//     qo_indptr built from a cached host buffer and copied to the GPU;
//   * otherwise: runs the dedicated decode kernel.
// Results are written into `output` (and `output_lse` when provided).
// NOTE(review): enable_cuda_graph, page_locked_int_workspace_buffer and
// kv_seq_lens are not referenced in this body.
void batch_decode(const std::string& uri,
                  torch::Tensor plan_info,
                  torch::Tensor float_workspace_buffer,
                  torch::Tensor int_workspace_buffer,
                  torch::Tensor page_locked_int_workspace_buffer,
                  torch::Tensor query,
                  torch::Tensor k_cache,
                  torch::Tensor v_cache,
                  torch::Tensor paged_kv_indptr,
                  torch::Tensor paged_kv_indices,
                  torch::Tensor paged_kv_last_page_len,
                  int64_t window_left,
                  double sm_scale,
                  torch::Tensor output,
                  std::optional<torch::Tensor>& output_lse,
                  bool enable_cuda_graph,
                  bool use_tensor_core,
                  torch::Tensor kv_seq_lens) {
  if (use_tensor_core) {
    // One entry per sequence in the batch.
    const int64_t batch_size = paged_kv_last_page_len.size(0);
    // NOTE(review): presumably get_cache_buffer returns [0, 1, ..., batch_size]
    // (one query token per sequence) — confirm. The .to(kCUDA) copy runs on
    // every call.
    torch::Tensor qo_indptr_host =
        get_cache_buffer(batch_size + 1, torch::kCPU);
    torch::Tensor qo_indptr = qo_indptr_host.to(torch::kCUDA);

    FunctionFactory::get_instance().fa2_prefill_paged_run_func(uri).call(
        float_workspace_buffer,
        int_workspace_buffer,
        plan_info,
        query,
        k_cache,
        v_cache,
        qo_indptr,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        output,
        output_lse,
        /*mask_mode_code=*/0,  // NON_CAUSAL
        /*kv_layout_code=*/0,  // NHD layout
        window_left,
        support_pdl(),
        /*maybe_custom_mask=*/std::optional<torch::Tensor>(),
        /*maybe_mask_indptr=*/std::optional<torch::Tensor>(),
        /*maybe_alibi_slopes=*/std::optional<torch::Tensor>(),
        /*maybe_prefix_len_ptr=*/std::optional<torch::Tensor>(),
        /*maybe_token_pos_in_items_ptr=*/std::optional<torch::Tensor>(),
        /*maybe_max_item_len_ptr=*/std::optional<torch::Tensor>(),
        /*logits_soft_cap=*/0.0,
        sm_scale,
        /*rope_rcp_scale=*/1.0,
        /*rope_rcp_theta=*/1.0 / 10000.0,
        /*token_pos_in_items_len=*/0);
  } else {
    // Dedicated FlashInfer decode kernel (no qo_indptr needed).
    FunctionFactory::get_instance().decode_run_func(uri).call(
        float_workspace_buffer,
        int_workspace_buffer,
        plan_info,
        query,
        k_cache,
        v_cache,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        output,
        output_lse,
        /*kv_layout_code=*/0,  // NHD layout
        window_left,
        support_pdl(),
        /*maybe_alibi_slopes=*/std::optional<torch::Tensor>(),
        /*logits_soft_cap=*/0.0,
        sm_scale,
        /*rope_rcp_scale=*/1.0,
        /*rope_rcp_theta=*/1.0 / 10000.0);
  }
}

 fa2_prefill_paged_run_func

  // Returns the typed dispatcher handle for "<uri>::paged_run". On the first
  // request for a given uri, it loads that uri's JIT-compiled FlashInfer
  // library.
  //
  // Handles are cached per uri, because different dtypes, head dims or
  // backends produce different uris and therefore different kernels. The
  // cache is guarded by a mutex so concurrent callers never read an entry
  // while it is being written. Loaded libraries are kept in a static list
  // so the operators they register stay available until static destruction.
  //
  // Throws if the library cannot be loaded or the schema is not registered.
  // Nothing is cached in that case, so a later call retries.
  FA2_PREFILL_PAGED_RUN_FUNC_TYPE fa2_prefill_paged_run_func(
      const std::string& uri) {
    static std::mutex cache_mu;
    static std::unordered_map<std::string, FA2_PREFILL_PAGED_RUN_FUNC_TYPE>
        funcs_by_uri;
    static std::vector<std::unique_ptr<torch::DynamicLibrary>> libs;

    std::lock_guard<std::mutex> lock(cache_mu);
    if (const auto it = funcs_by_uri.find(uri); it != funcs_by_uri.end()) {
      return it->second;
    }

    auto lib = std::make_unique<torch::DynamicLibrary>(
        path_to_uri_so_lib(uri).c_str(), nullptr, true);
    const std::string run_schema_name = uri + "::paged_run";
    FA2_PREFILL_PAGED_RUN_FUNC_TYPE f =
        torch::Dispatcher::singleton()
            .findSchemaOrThrow(run_schema_name.c_str(), "")
            .typed<void(torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        torch::Tensor,
                        std::optional<torch::Tensor>,
                        int64_t,
                        int64_t,
                        int64_t,
                        bool,
                        std::optional<torch::Tensor>,
                        std::optional<torch::Tensor>,
                        std::optional<torch::Tensor>,
                        std::optional<torch::Tensor>,
                        std::optional<torch::Tensor>,
                        std::optional<torch::Tensor>,
                        double,
                        double,
                        double,
                        double,
                        int64_t)>();

    // Commit only after the schema lookup succeeded.
    libs.push_back(std::move(lib));
    funcs_by_uri.emplace(uri, f);
    return f;
  }

 torch::DynamicLibrary加载动态库。动态库路径解析如下(根目录由环境变量FLASHINFER_OPS_PATH指定,路径形如$FLASHINFER_OPS_PATH/<uri>/<uri>.so):

// Returns the path of the JIT-compiled FlashInfer library for `uri`:
//   $FLASHINFER_OPS_PATH/<uri>/<uri>.so
std::string path_to_uri_so_lib(const std::string& uri) {
  std::string path = util::get_string_env("FLASHINFER_OPS_PATH");
  path += "/";
  path += uri;
  path += "/";
  path += uri;
  path += ".so";
  return path;
}
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐