Upgrading to a Qwen2Audio Real-Time Speech LLM Service with Multi-Client Concurrent Streaming Inference


🧠 Goals of This Upgrade

Building on the previous version, we add:

  1. Session management
    Each WebSocket connection (or request) maps to a unique session ID, which separates the caches and context of different clients.

  2. Concurrency-safe per-session cache (SessionCache)
    Each session independently maintains its own:

    • accumulated audio_embeds

    • prompt (customizable)

    • model_kv_cache (reserved as an extension point)

  3. Concurrent handling of multiple users streaming audio at the same time
    Built on FastAPI's async concurrency model plus asyncio locks.
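For quick reference, the WebSocket protocol implemented below reduces to three JSON message types. This is a sketch distilled from the server handlers that follow; the pcm payload is a plain list of float samples at 16 kHz, and "<sid>" stands for the session_id returned by "start":

start_msg = {"type": "start", "prompt": "Transcribe this: <|AUDIO|>"}    # opens a session, server replies with session_id
chunk_msg = {"type": "chunk", "session_id": "<sid>", "pcm": [0.0, ...]}  # one audio chunk to accumulate
end_msg   = {"type": "end",   "session_id": "<sid>"}                     # triggers generation, then the socket closes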


🚀 1. Full Server Implementation (server_vllm_audio_streaming_session.py)

import torch
import torchaudio
import asyncio
import json
import uuid
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from transformers import Qwen2AudioProcessor, Qwen2AudioEncoder, Qwen2ForCausalLM

# ============================================================
# Initialize the FastAPI app
# ============================================================
app = FastAPI(title="Qwen2Audio Streaming LLM Server (Multi-Session)")

# ============================================================
# Model loading
# ============================================================
MODEL_PATH = "Qwen/Qwen2-Audio-7B-Instruct"

print(f"[Init] Loading model from {MODEL_PATH} ...")
processor = Qwen2AudioProcessor.from_pretrained(MODEL_PATH)
# Simplified setup: load the audio tower and the LM separately from the combined
# checkpoint, keeping both in float16 so dtypes match end to end.
audio_tower = Qwen2AudioEncoder.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to("cuda")
language_model = Qwen2ForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map="auto"
)

# Simple linear projector mapping audio-encoder states into the LM embedding space
class AudioProjector(torch.nn.Module):
    def __init__(self, audio_dim, text_dim):
        super().__init__()
        self.linear = torch.nn.Linear(audio_dim, text_dim)

    def forward(self, x):
        return self.linear(x)

# Cast to float16 so its output matches the LM's dtype
projector = AudioProjector(
    audio_tower.config.hidden_size, language_model.config.hidden_size
).to(device="cuda", dtype=torch.float16)

# ============================================================
# Session cache structure
# ============================================================
class SessionCache:
    """每个客户端 WebSocket 连接对应的状态缓存"""
    def __init__(self):
        self.sessions = {}  # {session_id: {"embeds": [], "prompt": str, "lock": asyncio.Lock()}}

    def create(self, prompt: str):
        sid = str(uuid.uuid4())
        self.sessions[sid] = {
            "embeds": [],
            "prompt": prompt,
            "lock": asyncio.Lock(),
        }
        print(f"[Session] Created: {sid}")
        return sid

    def get(self, sid: str):
        return self.sessions.get(sid)

    def remove(self, sid: str):
        if sid in self.sessions:
            del self.sessions[sid]
            print(f"[Session] Removed: {sid}")

session_cache = SessionCache()

# ============================================================
# HTTP whole-clip inference (non-streaming)
# ============================================================
@app.post("/infer")
async def infer_audio(file_path: str, prompt: str = "Transcribe this: <|AUDIO|>"):
    waveform, sr = torchaudio.load(file_path)
    inputs = processor(audios=waveform[0], sampling_rate=sr, return_tensors="pt").to("cuda")
    features = audio_tower(**inputs).last_hidden_state
    embeds = projector(features)

    input_ids = language_model.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    outputs = language_model.generate(inputs_embeds=embeds, max_new_tokens=64)
    text = language_model.tokenizer.decode(outputs[0], skip_special_tokens=True)
    return JSONResponse({"text": text})

# ============================================================
# WebSocket streaming endpoint (multi-session)
# ============================================================
@app.websocket("/stream")
async def stream_audio(websocket: WebSocket):
    await websocket.accept()

    sid = None
    print("[WS] Client connected")

    try:
        while True:
            msg = await websocket.receive_text()
            data = json.loads(msg)

            # --- 1️⃣ Create a new session ---
            if data["type"] == "start":
                prompt = data.get("prompt", "Transcribe this: <|AUDIO|>")
                sid = session_cache.create(prompt)
                await websocket.send_text(json.dumps({"status": "started", "session_id": sid}))

            # --- 2️⃣ Receive an audio chunk ---
            elif data["type"] == "chunk" and sid:
                sess = session_cache.get(sid)
                if sess is None:
                    await websocket.send_text(json.dumps({"error": "Invalid session"}))
                    continue

                # Keep the raw PCM on CPU: the feature extractor expects host memory.
                chunk = torch.tensor(data["pcm"], dtype=torch.float32)
                async with sess["lock"]:
                    inputs = processor(audios=chunk, sampling_rate=16000, return_tensors="pt")
                    input_features = inputs.input_features.to(device="cuda", dtype=torch.float16)
                    features = audio_tower(input_features).last_hidden_state
                    sess["embeds"].append(projector(features))

                await websocket.send_text(json.dumps({"status": "ok", "session_id": sid, "chunk_len": len(chunk)}))

            # --- 3️⃣ Finish and generate the result ---
            elif data["type"] == "end" and sid:
                sess = session_cache.get(sid)
                if sess is None:
                    await websocket.send_text(json.dumps({"error": "Invalid session"}))
                    continue

                async with sess["lock"]:
                    if not sess["embeds"]:  # torch.cat would fail on an empty list
                        await websocket.send_text(json.dumps({"error": "No audio received"}))
                        continue
                    all_embeds = torch.cat(sess["embeds"], dim=1)
                    sess["embeds"].clear()

                    # Same prompt handling as /infer: prepend the embedded prompt tokens.
                    input_ids = processor.tokenizer(sess["prompt"], return_tensors="pt").input_ids.to("cuda")
                    text_embeds = language_model.get_input_embeddings()(input_ids)
                    outputs = language_model.generate(
                        inputs_embeds=torch.cat([text_embeds, all_embeds], dim=1), max_new_tokens=64
                    )
                    text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)

                await websocket.send_text(json.dumps({"status": "done", "session_id": sid, "text": text}))
                session_cache.remove(sid)
                await websocket.close()
                break

    except WebSocketDisconnect:
        print(f"[WS] Disconnected session {sid}")
        if sid:
            session_cache.remove(sid)
    except Exception as e:
        print("[Error]", e)
        if sid:
            session_cache.remove(sid)
        await websocket.close()

# Launch command:
# uvicorn server_vllm_audio_streaming_session:app --host 0.0.0.0 --port 8080
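One caveat: the encoder and generate() calls above run synchronously inside async handlers, so a busy session can stall the shared event loop for everyone. A minimal sketch of offloading the GPU-bound part (assuming Python 3.9+ for asyncio.to_thread; all names refer to the server code above):

def encode_chunk(chunk):
    # Synchronous GPU work, identical to the "chunk" branch above.
    inputs = processor(audios=chunk, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(device="cuda", dtype=torch.float16)
    return projector(audio_tower(input_features).last_hidden_state)

# Inside the "chunk" branch, the inline encoding then becomes:
#     async with sess["lock"]:
#         sess["embeds"].append(await asyncio.to_thread(encode_chunk, chunk))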

🧩 2. Client Example (multiple users connecting concurrently)

import asyncio
import websockets
import torchaudio
import json

async def stream_audio(file_path):
    uri = "ws://127.0.0.1:8080/stream"
    async with websockets.connect(uri) as ws:
        print("[Client] Connected")

        # 1️⃣ Start a new session
        await ws.send(json.dumps({"type": "start", "prompt": "Transcribe this: <|AUDIO|>"}))
        resp = await ws.recv()
        start_info = json.loads(resp)
        sid = start_info["session_id"]
        print("[Client] Session:", sid)

        # 2️⃣ Send audio chunks (assumes the file is already 16 kHz mono,
        # matching the sampling rate the server hard-codes)
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform[0]
        chunk_size = sr * 5  # 5-second chunks

        for i in range(0, len(waveform), chunk_size):
            chunk = waveform[i:i + chunk_size].tolist()
            await ws.send(json.dumps({"type": "chunk", "session_id": sid, "pcm": chunk}))
            resp = await ws.recv()
            print("[Ack]", resp)

        # 3️⃣ Send end-of-stream
        await ws.send(json.dumps({"type": "end", "session_id": sid}))
        result = await ws.recv()
        print("[Final]", result)

async def main():
    await asyncio.gather(
        stream_audio("sample1.wav"),
        stream_audio("sample2.wav"),
    )

if __name__ == "__main__":
    asyncio.run(main())
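To drive more than two concurrent sessions from a single process, gather over a list of files (the file names below are placeholders):

async def main_many(files):
    # One independent WebSocket session per file, all running concurrently.
    await asyncio.gather(*(stream_audio(f) for f in files))

# asyncio.run(main_many([f"sample{i}.wav" for i in range(1, 9)]))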

⚙️ 3. How to Run

🧱 1️⃣ Start the server

uvicorn server_vllm_audio_streaming_session:app --host 0.0.0.0 --port 8080

🧩 2️⃣ Start multiple clients

Each client can open its own independent WebSocket connection.
FastAPI's async event loop plus the SessionCache lets them run concurrently without interfering with one another.
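For example, assuming the client above is saved as client_streaming_session.py (the file name is illustrative):

# terminal 1
python client_streaming_session.py

# terminal 2
python client_streaming_session.py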


🧠 4. How Session Management Works

Field        Meaning
session_id   Unique identifier per connection (UUID)
embeds       Temporary store for all audio embeddings of the session
prompt       The prompt used by the current session
lock         asyncio lock that prevents concurrent writes within one session

Concurrency-safety strategy (see the demo sketch after this list):

  • Different WebSocket connections → independent sessions (executed concurrently)

  • Audio chunks within one session take the lock in turn (processed sequentially)
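A tiny self-contained demo of that rule (the session dicts are stand-ins for SessionCache entries): chunks of the same session append in order because asyncio.Lock wakes waiters FIFO, while the two sessions interleave freely:

import asyncio

async def handle_chunk(sess, tag):
    async with sess["lock"]:        # serializes work within one session
        await asyncio.sleep(0.01)   # stand-in for encoder work
        sess["embeds"].append(tag)

async def demo():
    a = {"embeds": [], "lock": asyncio.Lock()}
    b = {"embeds": [], "lock": asyncio.Lock()}
    await asyncio.gather(*(handle_chunk(s, f"{n}-{i}")
                           for i in range(3) for s, n in ((a, "A"), (b, "B"))))
    print(a["embeds"], b["embeds"])  # ['A-0', 'A-1', 'A-2'] ['B-0', 'B-1', 'B-2']

asyncio.run(demo())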


🔄 5. Extension Points

Possible further enhancements:

Direction                 Description
KV cache support          Replace language_model.generate() with incremental forward() calls that keep the context cached per session
Session timeout cleanup   Periodically scan for and evict idle sessions to prevent memory leaks (see the sketch below)
Persistent storage        Spill embeds to disk or Redis to support very long audio
Multi-GPU parallelism     Replace transformers with vLLM's AsyncLLMEngine
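For the timeout-cleanup row, a minimal sketch of a background reaper task (the last_active field and the 600-second threshold are assumptions; the chunk handler would need to refresh last_active on every message):

import time

@app.on_event("startup")
async def start_session_reaper():
    async def reaper():
        while True:
            await asyncio.sleep(60)  # scan once a minute
            now = time.time()
            for sid in list(session_cache.sessions):
                sess = session_cache.sessions.get(sid)
                if sess and now - sess.get("last_active", now) > 600:
                    session_cache.remove(sid)  # evict idle session
    asyncio.create_task(reaper())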

✅ 6. Sample Run (example logs)

[Init] Loading model from Qwen/Qwen2-Audio-7B-Instruct ...
[Session] Created: 31f8c9b7-12a2-4a15-9a17-343e3f9ad6c5
[Session] Created: 122fcb18-76b2-4f9a-9d7a-90a3289ac8cc
[WS] Client connected
[WS] Client connected
[Ack] {"status":"ok","session_id":"31f8...","chunk_len":80000}
[Ack] {"status":"ok","session_id":"122f...","chunk_len":80000}
[Final] {"status":"done","text":"this is qwen audio transcription"}