相关材料

[1] pd分离在vllm中用法
[2] vLLM PD分离方案入门:核心概念、优势与适应场景梳理

DisaggPDScheduler

 xllm在pd分离场景中,主要逻辑集中在DisaggPDScheduler。

// Constructs the disaggregated prefill/decode scheduler on top of
// ContinuousScheduler.
//
// Aborts (LOG(FATAL)) when no instance role is configured. Unless PD-OOC mode
// is enabled, it also starts the dispatch thread and the RPC server thread,
// sets up the RPC server/client, registers the instance info and, for MIX
// instances with profiling enabled, runs the TTFT/TPOT profiling.
DisaggPDScheduler::DisaggPDScheduler(Engine* engine, const Options& options)
    : ContinuousScheduler(engine, options), server_name_("DisaggPDServer") {
  if (!options_.instance_role().has_value()) {
    LOG(FATAL) << "Instance type is not set in disagg pd mode.";
  }

  // OOC mode (PDOOCScheduler) handles initialization in its own constructor,
  // so nothing else is set up here.
  if (options_.enable_pd_ooc()) {
    return;
  }

  // Background thread running dispatch_requests().
  dispatch_thread_ = std::make_unique<std::thread>(
      &DisaggPDScheduler::dispatch_requests, this);

  // Finalize the server name first: start_rpc_server() reads server_name_.
  server_name_.append(std::to_string(options.server_idx()));
  rpc_server_thread_ = std::make_unique<std::thread>(
      &DisaggPDScheduler::start_rpc_server, this);
  initialize_rpc_server_and_client(server_name_);
  register_instance_info(server_name_, engine);

  // Profile TTFT & TPOT and update instance info (MIX instances only).
  const bool profiling_enabled = !options_.disable_ttft_profiling();
  if (profiling_enabled &&
      options_.instance_role().value() == InstanceRole::MIX) {
    profile_ttft();
    profile_tpot();
  }
}

 在P实例,dispatch_thread_ 负责prefill的调度。

// Accepts a new request on the prefill side.
//
// The request must be non-null and own at least one sequence (CHECKed).
// After asking the KV cache manager to prefetch from storage, the request is
// queued for dispatch_requests(): offline requests go to the offline queue,
// all others to the online queue. Always returns true.
bool DisaggPDScheduler::add_request(std::shared_ptr<Request>& request) {
  CHECK(request != nullptr);
  CHECK(!request->sequences().empty());

  kv_cache_manager_->prefetch_from_storage(request);

  if (request->offline()) {
    // Offline traffic waits in its own queue.
    prefill_request_queue_offline_.enqueue(request);
  } else {
    // Online traffic: push and wait for the dispatch thread.
    prefill_request_queue_.enqueue(request);
  }
  return true;
}

// prefill send new request to remote instance
void DisaggPDScheduler::dispatch_requests() {
  while (true) {
    const auto timeout = std::chrono::milliseconds(100);
    // Wait for online request until timeout.
    // If timeout, try to get offline request once. If no offline request,
    // continue to wait for online request. This can avoid offline request
    // blocking online request for too long time.
    std::shared_ptr<Request> request;
    if (!prefill_request_queue_.wait_dequeue_timed(request, timeout)) {
      if (!prefill_request_queue_offline_.try_dequeue(request)) {
        continue;
      }
    }

    if (request == nullptr) {
      // nullptr is a signal to exit
      break;
    }

    std::vector<std::shared_ptr<Request>> requests;
    requests.emplace_back(request);
    std::string selected_instance = "";
    proto::DisaggPDService_Stub* stub = nullptr;

    if (selected_instance.empty() && !stub) {
      // get allocated decode instance list from Master
      while (decode_inst_names_.empty()) {
        decode_inst_names_ = xservice_client_->get_static_decode_list();
        if (!decode_inst_names_.empty()) {
          LOG(INFO) << "Get PD decode instance list: "
                    << absl::StrJoin(decode_inst_names_, "; ");
          break;
        }
        sleep(1);
      }
      // select a D instance use RR currently.
      // TODO: use better decode selection strategy later. maybe different
      // strategy for offline and online request. or implement in xllm service.
      int try_decode_count = 0;
      while (!stub) {
        if (try_decode_count == decode_inst_names_.size()) {
          LOG(FATAL) << "Can not connect to all decode instances.";
        }
        ++try_decode_count;
        selected_instance = decode_inst_names_[current_decode_idx_];
        current_decode_idx_ =
            (++current_decode_idx_) % decode_inst_names_.size();
        stub = create_rpc_channel(selected_instance);
      }
    }

    {
      std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
      for (auto& req : requests) {
        req_to_channel_map_[req->request_id()] = stub;
      }
    }

    // TODO: send the request to the selected D instance
    // Send 'DisaggRequests' and recv 'DisaggResponses'
    xllm::proto::DisaggRequests reqs;
    xllm::proto::DisaggResponses resps;
    // prefill name (ID)
    reqs.set_prefill_name(xservice_client_->get_instance_name());
    reqs.mutable_reqs()->Reserve(requests.size());
    // currently we only support one request once.
    for (size_t i = 0; i < requests.size(); ++i) {
      // proto::DisaggRequest req;
      auto req = reqs.mutable_reqs()->Add();
      req->set_req_id(requests[i]->request_id());
      req->set_service_req_id(requests[i]->service_request_id());
      req->set_tokens_num(requests[i]->state().prompt_tokens.size());
      req->set_prompt(requests[i]->state().prompt);
      ADD_VECTOR_TO_PROTO(req->mutable_prompt_tokens(),
                          requests[i]->state().prompt_tokens);
      req->set_stream(requests[i]->state().stream);
      req->set_x_request_id(requests[i]->x_request_id());
      req->set_x_request_time(requests[i]->x_request_time());
      req->set_seq_capacity(requests[i]->state().seq_capacity);
      req->set_max_tokens(
          requests[i]->state().stopping_checker.get_max_generated_tokens());
      req->set_max_context_len(
          requests[i]->state().stopping_checker.get_max_context_len());
      req->set_ignore_eos(
          requests[i]->state().stopping_checker.get_ignore_eos());
      req->set_eos_token_id(
          requests[i]->state().stopping_checker.get_eos_token());
      if (requests[i]->state().stopping_checker.get_stop_tokens().size() > 0) {
        ADD_VECTOR_TO_PROTO(
            req->mutable_stop_token_ids(),
            requests[i]->state().stopping_checker.get_stop_tokens());
      }
      if (requests[i]->state().stopping_checker.get_stop_sequences().size() >
          0) {
        for (auto& stop_sequence :
             requests[i]->state().stopping_checker.get_stop_sequences()) {
          // proto::StopSequence proto_seq;
          auto proto_seq = req->mutable_stop_sequences()->Add();
          ADD_VECTOR_TO_PROTO(proto_seq->mutable_seq_tokens(), stop_sequence);
          //*req->mutable_stop_sequences()->Add() = proto_seq;
        }
      }
      req->set_n(requests[i]->state().n);
      req->set_best_of(requests[i]->state().best_of);
      req->set_frequency_penalty(
          requests[i]->state().sampling_param.frequency_penalty);
      req->set_presence_penalty(
          requests[i]->state().sampling_param.presence_penalty);
      req->set_repetition_penalty(
          requests[i]->state().sampling_param.repetition_penalty);
      req->set_temperature(requests[i]->state().sampling_param.temperature);
      req->set_top_p(requests[i]->state().sampling_param.top_p);
      req->set_top_k(requests[i]->state().sampling_param.top_k);
      req->set_logprobs(requests[i]->state().sampling_param.logprobs);
      req->set_top_logprobs(requests[i]->state().sampling_param.top_logprobs);
      req->set_is_embeddings(requests[i]->state().sampling_param.is_embeddings);
      req->set_echo(requests[i]->state().echo);
      req->set_skip_special_tokens(requests[i]->state().skip_special_tokens);
      //*reqs.mutable_reqs()->Add() = req;
    }
    std::vector<std::string> device_ips;
    std::vector<uint16_t> ports;
    engine_->get_device_info(device_ips, ports);
    reqs.mutable_cluster_infos()->mutable_cluster_ids()->Add(
        instance_info_.cluster_ids.begin(), instance_info_.cluster_ids.end());
    reqs.mutable_cluster_infos()->mutable_addrs()->Add(
        instance_info_.addrs.begin(), instance_info_.addrs.end());
    reqs.mutable_cluster_infos()->mutable_device_ips()->Add(device_ips.begin(),
                                                            device_ips.end());
    reqs.mutable_cluster_infos()->mutable_ports()->Add(ports.begin(),
                                                       ports.end());
    reqs.mutable_cluster_infos()->set_dp_size(options_.dp_size());

    // TODO: sync rpc here currently
    brpc::Controller cntl;
    stub->AddNewRequests(&cntl, &reqs, &resps, nullptr);
    // TODO: error handler
    // if (rpc failed) {
    //  // push all request back to prefill_request_queue_
    //}

    // check reqs which can not dispatch to D instance,
    // and push back to prefill_request_queue_
    CHECK_EQ(requests.size(), resps.resps().size())
        << "selected_instance : " << selected_instance;
    // insert instance name to linked_instance_
    {
      std::lock_guard<std::mutex> lock(linked_instances_mutex_);
      linked_instance_.emplace(selected_instance);
    }
    for (size_t i = 0; i < requests.size(); ++i) {
      if (resps.resps()[i].status_code() != 200) {
        // push back to prefill_request_queue_
        if (requests[i]->offline()) {
          prefill_request_queue_offline_.enqueue(requests[i]);
        } else {
          prefill_request_queue_.enqueue(requests[i]);
        }

      } else {
        for (auto& sequence : requests[i]->sequences()) {
          TransferKVInfo info;
          info.request_id = requests[i]->request_id();
          for (auto& bid : resps.resps()[i].blocks_ids()) {
            info.remote_blocks_ids.emplace_back(bid);
          }
          info.dp_rank = resps.resps()[i].dp_rank();
          // TODO: remote_instances_info_ is not multi-thread safe.
          info.remote_instance_info = remote_instances_info_[selected_instance];
          sequence->kv_state().set_transfer_kv_info(std::move(info));
        }

        // push to request_queue_, and will be executed by engine.
        request_queue_.write(requests[i]);
      }
    }
  }
}

 add_request将request放入prefill_request_queue_。
 在dispatch_requests函数中,创建stub = create_rpc_channel。decode_address有两个来源:1 requests中携带的,2 使用xservice_client_->get_static_decode_list()获取。
 xservice_client_中需要同etcd交互。
 XllmRpcService定义的接口。

// RPC interface of the xllm service. According to the surrounding notes it is
// served by the separate xllm-service project.
// NOTE(review): the per-RPC semantics are not visible in this file; the
// comments below are limited to what the names and callers here show.
service XllmRpcService {
  rpc RegisterInstance(InstanceMetaInfo) returns (StatusCode) {}
  rpc GetInstanceInfo(InstanceID) returns (InstanceMetaInfo) {}
  rpc Heartbeat(HeartbeatRequest) returns (Status) {}
  // Presumably backs xservice_client_->get_static_decode_list(), which the
  // prefill dispatcher uses to find decode instances - confirm.
  rpc GetStaticDecodeList(InstanceID) returns (InstanceIDs) {}
  rpc GetStaticPrefillList(InstanceID) returns (InstanceIDs) {}
  // In disagg PD mode the xllm service receives responses directly from the
  // decode instance. This eliminates the cost of forwarding them through the
  // prefill instance.
  rpc Generations(xllm.proto.DisaggStreamGenerations) returns (xllm.proto.StatusSet) {}
}

 服务由另一个工程xllm-service提供。
 stub->AddNewRequests(&cntl, &reqs, &resps, nullptr),发起远程调用。对应的服务接口为DisaggPDService::AddNewRequests。
 reqs.mutable_cluster_infos()设置P实例集群的cluster_ids、addrs、device_ips、ports以及dp_size。
 在D实例,rpc_server_thread_ 负责启动DisaggPDService。

// Thread entry point that brings up the DisaggPDService RPC server.
//
// Creates the service object, registers a server under server_name_ in the
// ServerRegistry and starts it with that service. A start failure is only
// logged; the function returns either way.
void DisaggPDScheduler::start_rpc_server() {
  auto service = std::make_unique<DisaggPDService>(this, engine_);
  auto rpc_server =
      ServerRegistry::get_instance().register_server(server_name_);

  const bool started = rpc_server->start(std::move(service));
  if (started) {
    return;
  }
  LOG(ERROR) << "Failed to start brpc disagg pd server on port "
             << FLAGS_disagg_pd_port;
}

// RPC service class deriving from proto::DisaggPDService. The member
// declarations are omitted in this excerpt; out-of-class definitions of
// AddNewRequests and FirstGeneration appear elsewhere in this file.
class DisaggPDService : public proto::DisaggPDService {};

DisaggPDService

 定义的RPC接口:

// RPC interface between prefill (P) and decode (D) instances.
// NOTE(review): only AddNewRequests and FirstGeneration have visible callers
// and handlers in this file; the purpose of the other RPCs is not shown here.
service DisaggPDService {
  rpc Generation(DisaggStreamGeneration) returns (Status) {}
  rpc Generations(DisaggStreamGenerations) returns (StatusSet) {}
  // Called by the prefill dispatcher to hand new requests to a decode
  // instance; the response carries one entry per request.
  rpc AddNewRequests(DisaggRequests) returns (DisaggResponses) {}
  // Called by the prefill instance to deliver the first generated token of a
  // request to the decode instance.
  rpc FirstGeneration(DisaggGenerationsRequests) returns (Status) {}
  rpc MultiGenerations(DisaggGenerationsRequests) returns (Status) {}
  rpc SendPullSignal(PullSignal) returns (Status) {}
}

DisaggPDService::AddNewRequests

// RPC handler for AddNewRequests.
//
// Forwards the incoming requests to the service implementation, which fills
// `response`. `done` is run when this function returns, via the ClosureGuard.
// `controller` is not used.
void DisaggPDService::AddNewRequests(
    ::google::protobuf::RpcController* controller,
    const proto::DisaggRequests* request,
    proto::DisaggResponses* response,
    ::google::protobuf::Closure* done) {
  // Runs `done` on scope exit.
  brpc::ClosureGuard done_guard(done);

  // The implementation tries to allocate blocks for the new requests.
  disagg_pd_service_impl_->decode_recv_new_requests(request, response);
}

// Decode-side handling of a batch of new requests sent by a prefill instance.
//
// If the prefill instance is not linked yet, it is linked first using the
// cluster info carried in the message. Then, for every request, exactly one
// response entry is appended to `response`:
//   - 500: the request could not be built, the prefill instance could not be
//          linked, or scheduling on the decode side failed;
//   - 404: no KV blocks could be allocated for the request;
//   - 200: accepted; the entry also carries the dp rank and the ids of the
//          non-shared KV blocks allocated for the first sequence.
//
// The caller checks that the number of responses equals the number of
// requests, so every exit path must emit one response per request.
void DisaggPDServiceImpl::decode_recv_new_requests(
    const proto::DisaggRequests* request,
    proto::DisaggResponses* response) {
  // link prefill instance
  if (!scheduler_->is_instance_linked(request->prefill_name())) {
    const auto& cluster_infos = request->cluster_infos();
    std::vector<uint64_t> cluster_ids(cluster_infos.cluster_ids().begin(),
                                      cluster_infos.cluster_ids().end());
    std::vector<std::string> addrs(cluster_infos.addrs().begin(),
                                   cluster_infos.addrs().end());
    std::vector<std::string> device_ips(cluster_infos.device_ips().begin(),
                                        cluster_infos.device_ips().end());
    std::vector<uint16_t> ports(cluster_infos.ports().begin(),
                                cluster_infos.ports().end());
    int32_t dp_size = cluster_infos.dp_size();
    if (!scheduler_->link_instance(request->prefill_name(),
                                   cluster_ids,
                                   addrs,
                                   device_ips,
                                   ports,
                                   dp_size)) {
      LOG(ERROR) << "Link instance failed, instance name : "
                 << request->prefill_name();
      // Report the failure per request rather than returning an empty
      // response, which the caller cannot match against its requests.
      for (const auto& req : request->reqs()) {
        auto* resp = response->add_resps();
        resp->set_req_id(req.req_id());
        resp->set_status_code(500);
      }
      return;
    }
  }

  for (auto& req : request->reqs()) {
    auto resp = response->add_resps();
    resp->set_req_id(req.req_id());

    auto new_request = generate_request(req);
    if (new_request == nullptr) {
      resp->set_status_code(500);
      continue;
    }

    auto& sequences = new_request->sequences();
    Sequence* sequence = sequences[0].get();

    if (!scheduler_->try_allocate(sequence)) {
      // FIXME: set status code
      resp->set_status_code(404);
      continue;
    }

    // push the request to scheduler request buffer
    bool success =
        scheduler_->decode_schedule(new_request, request->prefill_name());
    if (!success) {
      LOG(ERROR) << "Failed to schedule new decode instance request: "
                 << req.req_id();
      // request and blocks are released in scheduler, so stop here: the
      // sequence must not be read again and the status must stay 500.
      resp->set_status_code(500);
      continue;
    }

    resp->set_dp_rank(sequence->dp_rank());

    // Only blocks after the shared prefix are reported.
    size_t shared_num = sequence->kv_state().shared_kv_blocks_num();
    const auto& blocks = sequence->kv_state().kv_blocks();
    for (size_t i = shared_num; i < blocks.size(); i++) {
      *(resp->mutable_blocks_ids()->Add()) = blocks[i].id();
    }

    resp->set_status_code(200);
  }
}

scheduler_->link_instance

D实例同P实例建立链接

// Links this instance's engine to the cluster described by the arguments and,
// on success, records `instance_name` in linked_instance_.
//
// linked_instances_mutex_ is held for the whole call, including the engine's
// link_cluster() call. Returns false (after logging) when linking fails.
bool DisaggPDScheduler::link_instance(
    const std::string& instance_name,
    const std::vector<uint64_t>& cluster_ids,
    const std::vector<std::string>& addrs,
    const std::vector<std::string>& device_ips,
    const std::vector<uint16_t>& ports,
    const int32_t dp_size) {
  std::lock_guard<std::mutex> guard(linked_instances_mutex_);

  const bool linked =
      engine_->link_cluster(cluster_ids, addrs, device_ips, ports, dp_size);
  if (linked) {
    linked_instance_.emplace(instance_name);
    return true;
  }

  LOG(ERROR) << "Link cluster failed!";
  return false;
}

bool LLMEngine::link_cluster(const std::vector<uint64_t>& cluster_ids,
                             const std::vector<std::string>& addrs,
                             const std::vector<std::string>& device_ips,
                             const std::vector<uint16_t>& ports,
                             const int32_t src_dp_size) {
  // Indicate which worker in the dp group in prefill the current worker needs
  // to connect to. First, we connect the rank 0 workers in each DP. Then,
  // increment the ranks sequentially.
  int32_t src_dp_worker_index = 0;
  int32_t src_world_size = cluster_ids.size();
  int32_t src_tp_size = src_world_size / src_dp_size;

  std::vector<folly::SemiFuture<bool>> futures;
  futures.reserve(worker_clients_num_);
  for (size_t worker_rank = 0; worker_rank < worker_clients_num_;
       ++worker_rank) {
    // The worker for decoding needs to establish a connection for each dp group
    // in prefill.
    std::vector<uint64_t> dp_cluster_ids;
    std::vector<std::string> dp_addrs;
    std::vector<std::string> dp_device_ips;
    std::vector<uint16_t> dp_ports;
    dp_cluster_ids.reserve(src_dp_size);
    dp_addrs.reserve(src_dp_size);
    dp_device_ips.reserve(src_dp_size);
    dp_ports.reserve(src_dp_size);
    for (int32_t i = 0; i < src_dp_size; ++i) {
      int32_t src_worker_index = i * src_tp_size + src_dp_worker_index;
      dp_cluster_ids.emplace_back(cluster_ids[src_worker_index]);
      dp_addrs.emplace_back(addrs[src_worker_index]);
      dp_device_ips.emplace_back(device_ips[src_worker_index]);
      dp_ports.emplace_back(ports[src_worker_index]);
    }
    // Increment the rank.
    src_dp_worker_index = (src_dp_worker_index + 1) % src_tp_size;

    folly::Promise<bool> promise;
    auto future = promise.getSemiFuture();
    link_threadpool_->schedule([this,
                                promise = std::move(promise),
                                worker_rank,
                                dp_cluster_ids = std::move(dp_cluster_ids),
                                dp_addrs = std::move(dp_addrs),
                                dp_device_ips = std::move(dp_device_ips),
                                dp_ports = std::move(dp_ports)]() mutable {
      promise.setValue(worker_clients_[worker_rank]->link_cluster(
          dp_cluster_ids, dp_addrs, dp_device_ips, dp_ports));
    });
    futures.emplace_back(std::move(future));
  }

  // wait for all futures to complete
  auto results = folly::collectAll(futures).get();
  for (const auto& result : results) {
    if (!result.value()) {
      LOG(ERROR) << "Link cluster failed.";
      return false;
    }
  }
  return true;
}

prefill_send_first_generation

 P实例通知D拉取KV cache。

void DisaggPDScheduler::prefill_send_first_generation() {
  if (running_sequences_.size() == 0) {
    return;
  }

  std::vector<std::shared_ptr<Request>> requests;
  std::vector<std::shared_ptr<Request>> non_stream_requests;
  requests.reserve(running_requests_.size());
  non_stream_requests.reserve(running_requests_.size());
  for (size_t i = 0; i < running_requests_.size(); ++i) {
    auto request = running_requests_[i];
    // Check if the request is a recently completed prefill request
    if (request->sequences()[0]->num_generated_tokens() == 1) {
      requests.emplace_back(request);
      if (!request->state().stream) {
        non_stream_requests.emplace_back(request);
      }
      running_requests_[i] = nullptr;
    }
  }
  // call non_stream_request's callback in P instance when its prefill ends
  response_processor_->process_completed_requests(non_stream_requests);

  // No prefill request needs to be transferred to decode.
  if (requests.size() == 0) {
    return;
  }

  prefill_threadpool_.schedule([this,
                                requests = std::move(requests)]() mutable {
    // send request first token to remote instance
    // TODO: here we only support one sequence for now.
    for (auto& request : requests) {
      // TODO: support batch request later
      proto::DisaggGenerationsRequests gens;
      auto gen = gens.mutable_multi_gens()->Add();
      gen->set_req_id(request->request_id());
      if (request->sequences()[0]->first_token().has_value()) {
        auto token = gen->mutable_tokens()->Add();
        token->set_token_id(
            request->sequences()[0]->first_token().value().token_id);
        if (request->sequences()[0]
                ->first_token()
                .value()
                .token_logprob.has_value()) {
          token->set_logprob(request->sequences()[0]
                                 ->first_token()
                                 .value()
                                 .token_logprob.value());
          token->set_has_logprob(true);
        } else {
          token->set_has_logprob(false);
        }
        ADD_VECTOR_TO_PROTO(
            token->mutable_top_tokens(),
            request->sequences()[0]->first_token().value().token_top_tokens);
        ADD_VECTOR_TO_PROTO(
            token->mutable_top_logprobs(),
            request->sequences()[0]->first_token().value().token_top_logprobs);
      }
      gen->set_kv_cache_transfer_mode(options_.kv_cache_transfer_mode());
      if (options_.kv_cache_transfer_mode() == "PULL") {
        ADD_VECTOR_TO_PROTO(gen->mutable_cluster_ids(),
                            instance_info_.cluster_ids);
        ADD_VECTOR_TO_PROTO(gen->mutable_addrs(), instance_info_.addrs);
        ADD_VECTOR_TO_PROTO(gen->mutable_k_cache_ids(),
                            instance_info_.k_cache_ids);
        ADD_VECTOR_TO_PROTO(gen->mutable_v_cache_ids(),
                            instance_info_.v_cache_ids);

        const auto blocks = request->sequences()[0]->kv_state().kv_blocks();
        std::vector<uint64_t> block_ids;
        block_ids.reserve(blocks.size());
        for (const auto& block : blocks) {
          block_ids.push_back(block.id());
        }
        ADD_VECTOR_TO_PROTO(gen->mutable_block_ids(), block_ids);
        gen->set_dp_size(instance_info_.dp_size);
        gen->set_dp_rank(request->sequences()[0]->dp_rank());
      }

      // send first gens to remote instance
      proto::DisaggPDService_Stub* stub = nullptr;
      {
        std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
        // now we only support one request once.
        stub = req_to_channel_map_[request->request_id()];
      }

      // TODO: Async call later
      proto::Status resp;
      brpc::Controller cntl;
      stub->FirstGeneration(&cntl, &gens, &resp, nullptr);

      if (cntl.Failed() || !resp.ok()) {
        LOG(ERROR) << "Failed to send first generation, " << cntl.ErrorText()
                   << ", staus: " << resp.ok();
      }

      {
        std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
        req_to_channel_map_.erase(request->request_id());
      }
      kv_cache_manager_->deallocate(request.get());
    }
  });
}

 stub->FirstGeneration,向D实例发起远程调用。

DisaggPDService::FirstGeneration

D实例,KV cache拉取过程。

// TODO: support embedding later, now we only support tokens
void DisaggPDService::FirstGeneration(
    ::google::protobuf::RpcController* controller,
    const proto::DisaggGenerationsRequests* request,
    proto::Status* response,
    ::google::protobuf::Closure* done) {
  // Receive first token from Prefill, schedule the request to running queue
  brpc::ClosureGuard done_guard(done);
  disagg_pd_service_impl_->decode_recv_first_generation(request, response);
}

// TODO: support embedding later, now we only support tokens
void DisaggPDServiceImpl::decode_recv_first_generation(
    const proto::DisaggGenerationsRequests* request,
    proto::Status* response) {
  // TODO: we only support one request generation currently
  for (auto& gen : request->multi_gens()) {
    // Process the first token from the tokens array
    if (gen.tokens().empty()) {
      response->set_ok(false);
      return;
    }
    std::vector<std::string> addrs(gen.addrs().begin(), gen.addrs().end());
    bool success =
        scheduler_->decode_recv_first_generation(gen.req_id(),
                                                 first_token.token_id(),
                                                 first_token.has_logprob(),
                                                 first_token.logprob(),
                                                 std::move(top_tokens),
                                                 std::move(top_logprobs),
                                                 gen.kv_cache_transfer_mode(),
                                                 std::move(cluster_ids),
                                                 std::move(addrs),
                                                 std::move(k_cache_ids),
                                                 std::move(v_cache_ids),
                                                 std::move(block_ids),
                                                 gen.dp_size(),
                                                 gen.dp_rank());

  }
}

bool DisaggPDScheduler::decode_recv_first_generation(
    const std::string& req_id,
    int64_t token_id,
    bool has_logprob,
    float logprob,
    std::vector<int64_t> top_tokens,
    std::vector<float> top_logprobs,
    const std::string& kv_cache_transfer_mode,
    std::vector<uint64_t> src_cluster_ids,
    std::vector<std::string> src_addrs,
    std::vector<int64_t> src_k_cache_ids,
    std::vector<int64_t> src_v_cache_ids,
    std::vector<uint64_t> src_block_ids,
    int32_t src_dp_size,
    int32_t src_dp_rank) {

  // push to request_queue_, and will be executed by engine.
  std::shared_ptr<Request> request = nullptr;
  {
    std::lock_guard<std::mutex> lock(received_request_map_mutex_);
    auto it = received_request_map_.find(req_id);
    if (it == received_request_map_.end()) {
      LOG(ERROR) << "Failed to find request, request id: " << req_id;
      return false;
    }
    request = it->second;
    received_request_map_.erase(it);
  }

  // pull kv cache
  if (kv_cache_transfer_mode == "PULL") {
    const auto blocks = request->sequences()[0]->kv_state().kv_blocks();
    std::vector<uint64_t> dst_block_ids;
    dst_block_ids.reserve(blocks.size());
    for (const auto& block : blocks) {
      dst_block_ids.push_back(block.id());
    }

    int32_t dst_dp_rank = request->sequences()[0]->dp_rank();
    engine_->pull_kv_blocks(src_dp_size,
                            src_dp_rank,
                            src_cluster_ids,
                            src_addrs,
                            src_k_cache_ids,
                            src_v_cache_ids,
                            src_block_ids,
                            dst_dp_rank,
                            dst_block_ids);
  }
}

 LLMEngine::pull_kv_blocks, 针对每一个dp组,我觉得P实例上的src_tp_size 同D实例上dst_tp_size是相等的。就是需要kvcache传输的P实例和D实例有相同数量的worker。

bool LLMEngine::pull_kv_blocks(const int32_t src_dp_size,
                               const int32_t src_dp_rank,
                               const std::vector<uint64_t>& src_cluster_ids,
                               const std::vector<std::string>& src_addrs,
                               const std::vector<int64_t>& src_k_cache_ids,
                               const std::vector<int64_t>& src_v_cache_ids,
                               const std::vector<uint64_t>& src_blocks,
                               const int32_t dst_dp_rank,
                               const std::vector<uint64_t>& dst_blocks) {
  int32_t src_world_size = src_cluster_ids.size();
  int32_t src_tp_size = src_world_size / src_dp_size;
  int32_t dst_world_size = options_.nnodes();
  int32_t dst_tp_size = dst_world_size / dp_size_;

  std::vector<bool> results;
  results.reserve(dst_tp_size);
  // Pull the KV cache for all workers in the current DP rank.
  for (size_t tp_rank = 0; tp_rank < dst_tp_size; ++tp_rank) {
    int32_t dst_worker_rank = dst_dp_rank * dst_tp_size + tp_rank;
    // Determine the ranks of the remote workers connected to the current
    // worker.
    int32_t src_dp_worker_rank = dst_worker_rank % src_tp_size;
    int32_t src_worker_rank = src_dp_rank * src_tp_size + src_dp_worker_rank;
    results.push_back(worker_clients_[dst_worker_rank]->pull_kv_blocks(
        src_cluster_ids[src_worker_rank],
        src_addrs[src_worker_rank],
        src_k_cache_ids[src_worker_rank],
        src_v_cache_ids[src_worker_rank],
        src_blocks,
        dst_blocks));
  }

  for (bool result : results) {
    if (!result) {
      return false;
    }
  }
  return true;
}

// Forwards a KV block pull request to this worker's communication channel and
// returns the channel's result unchanged.
bool RemoteWorker::pull_kv_blocks(const uint64_t src_cluster_id,
                                  const std::string& src_addr,
                                  const int64_t src_k_cache_id,
                                  const int64_t src_v_cache_id,
                                  const std::vector<uint64_t>& src_blocks,
                                  const std::vector<uint64_t>& dst_blocks) {
  const bool pulled = channel_->pull_kv_blocks(src_cluster_id,
                                               src_addr,
                                               src_k_cache_id,
                                               src_v_cache_id,
                                               src_blocks,
                                               dst_blocks);
  return pulled;
}

bool CommChannel::pull_kv_blocks(const uint64_t src_cluster_id,
                                 const std::string& src_addr,
                                 const int64_t src_k_cache_id,
                                 const int64_t src_v_cache_id,
                                 const std::vector<uint64_t>& src_blocks,
                                 const std::vector<uint64_t>& dst_blocks) {
  proto::PullKVCacheRequest request;
  request.set_cluster_id(src_cluster_id);
  request.set_addr(src_addr);
  request.set_k_cache_id(src_k_cache_id);
  request.set_v_cache_id(src_v_cache_id);

  ADD_VECTOR_TO_PROTO(request.mutable_src_blocks(), src_blocks);
  ADD_VECTOR_TO_PROTO(request.mutable_dst_blocks(), dst_blocks);

  proto::Status s;
  brpc::Controller cntl;
  stub_->PullKVCache(&cntl, &request, &s, nullptr);

  return !cntl.Failed() && s.ok();
}


// NOTE(review): abridged repeat of CommChannel::pull_kv_blocks that keeps
// only the PullKVCache call. `cntl`, `request` and `s` are not declared here
// and nothing is returned, so this excerpt is illustrative only; the complete
// definition appears earlier in this file.
bool CommChannel::pull_kv_blocks(const uint64_t src_cluster_id,
                                 const std::string& src_addr,
                                 const int64_t src_k_cache_id,
                                 const int64_t src_v_cache_id,
                                 const std::vector<uint64_t>& src_blocks,
                                 const std::vector<uint64_t>& dst_blocks) {
    // The RPC that reaches WorkerService::PullKVCache on the worker side.
    stub_->PullKVCache(&cntl, &request, &s, nullptr);
}

// RPC handler for PullKVCache.
//
// The work is handed to threadpool_: the task unpacks the request, starts the
// worker's asynchronous KV block pull, waits for its result and stores it in
// `resp`. `done` is run by the ClosureGuard when the task finishes, not when
// this handler returns.
void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
                                const proto::PullKVCacheRequest* req,
                                proto::Status* resp,
                                ::google::protobuf::Closure* done) {
  threadpool_->schedule([this, controller, req, resp, done]() mutable {
    brpc::ClosureGuard done_guard(done);

    // Copy the request fields out of the proto message.
    const uint64_t cluster_id = req->cluster_id();
    const std::string src_addr = req->addr();
    const int64_t k_cache_id = req->k_cache_id();
    const int64_t v_cache_id = req->v_cache_id();
    const std::vector<uint64_t> src_block_ids(req->src_blocks().begin(),
                                              req->src_blocks().end());
    const std::vector<uint64_t> dst_block_ids(req->dst_blocks().begin(),
                                              req->dst_blocks().end());

    // Block this pool thread until the transfer has completed.
    const bool pulled = worker_
                            ->pull_kv_blocks_async(cluster_id,
                                                   src_addr,
                                                   k_cache_id,
                                                   v_cache_id,
                                                   src_block_ids,
                                                   dst_block_ids)
                            .get();
    resp->set_ok(pulled);
  });
}

// Starts an asynchronous pull of KV blocks from a source worker.
//
// When built with USE_NPU the call is delegated to kv_cache_transfer_;
// in any other build the returned future holds false.
folly::SemiFuture<bool> WorkerImpl::pull_kv_blocks_async(
    uint64_t src_cluster_id,
    const std::string& src_addr,
    int64_t src_k_cache_id,
    int64_t src_v_cache_id,
    const std::vector<uint64_t>& src_blocks,
    const std::vector<uint64_t>& dst_blocks) {
#if defined(USE_NPU)
  return kv_cache_transfer_->pull_kv_blocks_async(src_cluster_id,
                                                  src_addr,
                                                  src_k_cache_id,
                                                  src_v_cache_id,
                                                  src_blocks,
                                                  dst_blocks);
#else
  // Only the USE_NPU build performs a transfer.
  return false;
#endif
}

 kv_cache_transfer_由KVCacheTransferFactory::create创建,有三种实现。

  • LlmDataDist
  • MooncakeKVCacheTransfer
  • HcclKVCacheTransfer
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐