LLMs: a first look at deploying qwen+audio with vLLM (part 4)
This iteration upgrades the service into a real-time Qwen2Audio speech LLM server that supports concurrent streaming inference from multiple clients.
🧠 Goals of this revision
Building on the previous version, we add:
- ✅ Session management: each WebSocket connection (or request) maps to a unique session ID, used to keep the caches and contexts of different clients apart.
- ✅ Concurrency-safe session cache: each session independently maintains its own:
  - accumulated audio_embeds
  - prompt (customizable)
  - model_kv_cache (reserved for future extension)
- ✅ Concurrent handling of multiple users pushing streaming audio at the same time, built on FastAPI's async concurrency plus asyncio locks.
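The locking scheme in the last bullet can be demonstrated in isolation with plain asyncio, no model needed: concurrent writers to one session are serialized by that session's lock, so chunks land in arrival order. A minimal sketch:

```python
import asyncio

async def append_chunk(session, chunk_id):
    # Every writer for a given session contends on that session's lock,
    # so chunks are appended strictly in arrival order.
    async with session["lock"]:
        await asyncio.sleep(0.01)  # simulate audio-encoding work
        session["embeds"].append(chunk_id)

async def demo():
    session = {"embeds": [], "lock": asyncio.Lock()}
    await asyncio.gather(*(append_chunk(session, i) for i in range(5)))
    return session["embeds"]

order = asyncio.run(demo())
print(order)  # [0, 1, 2, 3, 4]
```

asyncio.Lock wakes waiters in FIFO order, which is what keeps same-session chunks sequential even under concurrency.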
🚀 1. Full server implementation (server_vllm_audio_streaming_session.py)
import torch
import torchaudio
import asyncio
import json
import uuid
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from transformers import Qwen2AudioProcessor, Qwen2AudioEncoder, Qwen2ForCausalLM

# ============================================================
# Initialize the FastAPI app
# ============================================================
app = FastAPI(title="Qwen2Audio Streaming LLM Server (Multi-Session)")

# ============================================================
# Model loading
# ============================================================
MODEL_PATH = "Qwen/Qwen2-Audio-7B-Instruct"
print(f"[Init] Loading model from {MODEL_PATH} ...")
processor = Qwen2AudioProcessor.from_pretrained(MODEL_PATH)
audio_tower = Qwen2AudioEncoder.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16
).to("cuda")
language_model = Qwen2ForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map="auto"
)

# A simple linear projector from audio-encoder space to LLM embedding space
class AudioProjector(torch.nn.Module):
    def __init__(self, audio_dim, text_dim):
        super().__init__()
        self.linear = torch.nn.Linear(audio_dim, text_dim)

    def forward(self, x):
        return self.linear(x)

projector = AudioProjector(
    audio_tower.config.hidden_size, language_model.config.hidden_size
).to("cuda", dtype=torch.float16)
# ============================================================
# Session cache structure
# ============================================================
class SessionCache:
    """Per-client state cache, one entry per WebSocket connection."""
    def __init__(self):
        # {session_id: {"embeds": [], "prompt": str, "lock": asyncio.Lock()}}
        self.sessions = {}

    def create(self, prompt: str):
        sid = str(uuid.uuid4())
        self.sessions[sid] = {
            "embeds": [],
            "prompt": prompt,
            "lock": asyncio.Lock(),
        }
        print(f"[Session] Created: {sid}")
        return sid

    def get(self, sid: str):
        return self.sessions.get(sid)

    def remove(self, sid: str):
        if sid in self.sessions:
            del self.sessions[sid]
            print(f"[Session] Removed: {sid}")

session_cache = SessionCache()
# ============================================================
# HTTP whole-file inference (non-streaming)
# ============================================================
@app.post("/infer")
async def infer_audio(file_path: str, prompt: str = "Transcribe this: <|AUDIO|>"):
    waveform, sr = torchaudio.load(file_path)
    inputs = processor(audios=waveform[0].numpy(), sampling_rate=sr, return_tensors="pt").to("cuda")
    features = audio_tower(**inputs).last_hidden_state
    audio_embeds = projector(features).to(language_model.dtype)
    # Embed the prompt tokens and prepend them to the projected audio embeddings
    input_ids = processor.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    prompt_embeds = language_model.get_input_embeddings()(input_ids)
    inputs_embeds = torch.cat([prompt_embeds, audio_embeds], dim=1)
    outputs = language_model.generate(inputs_embeds=inputs_embeds, max_new_tokens=64)
    text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
    return JSONResponse({"text": text})
# ============================================================
# WebSocket streaming endpoint (multi-session)
# ============================================================
@app.websocket("/stream")
async def stream_audio(websocket: WebSocket):
    await websocket.accept()
    sid = None
    print("[WS] Client connected")
    try:
        while True:
            msg = await websocket.receive_text()
            data = json.loads(msg)

            # --- 1️⃣ Start a new session ---
            if data["type"] == "start":
                prompt = data.get("prompt", "Transcribe this: <|AUDIO|>")
                sid = session_cache.create(prompt)
                await websocket.send_text(json.dumps({"status": "started", "session_id": sid}))

            # --- 2️⃣ Receive an audio chunk ---
            elif data["type"] == "chunk" and sid:
                sess = session_cache.get(sid)
                if sess is None:
                    await websocket.send_text(json.dumps({"error": "Invalid session"}))
                    continue
                # Keep the raw PCM on CPU: the feature extractor expects host
                # memory; the extracted features are moved to GPU afterwards.
                chunk = torch.tensor(data["pcm"], dtype=torch.float32)
                async with sess["lock"]:
                    inputs = processor(audios=chunk.numpy(), sampling_rate=16000, return_tensors="pt").to("cuda")
                    features = audio_tower(**inputs).last_hidden_state
                    projected = projector(features)
                    sess["embeds"].append(projected)
                await websocket.send_text(json.dumps({"status": "ok", "session_id": sid, "chunk_len": len(chunk)}))

            # --- 3️⃣ Finish the stream and generate the result ---
            elif data["type"] == "end" and sid:
                sess = session_cache.get(sid)
                if sess is None:
                    await websocket.send_text(json.dumps({"error": "Invalid session"}))
                    continue
                if not sess["embeds"]:
                    await websocket.send_text(json.dumps({"error": "No audio received"}))
                    continue
                async with sess["lock"]:
                    all_embeds = torch.cat(sess["embeds"], dim=1).to(language_model.dtype)
                    sess["embeds"].clear()
                    input_ids = processor.tokenizer(sess["prompt"], return_tensors="pt").input_ids.to("cuda")
                    prompt_embeds = language_model.get_input_embeddings()(input_ids)
                    outputs = language_model.generate(
                        inputs_embeds=torch.cat([prompt_embeds, all_embeds], dim=1),
                        max_new_tokens=64,
                    )
                    text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
                await websocket.send_text(json.dumps({"status": "done", "session_id": sid, "text": text}))
                session_cache.remove(sid)
                await websocket.close()
                break
    except WebSocketDisconnect:
        print(f"[WS] Disconnected session {sid}")
        if sid:
            session_cache.remove(sid)
    except Exception as e:
        print("[Error]", e)
        if sid:
            session_cache.remove(sid)
        await websocket.close()

# Launch command:
# uvicorn server_vllm_audio_streaming_session:app --host 0.0.0.0 --port 8080
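One detail the server glosses over: the prompt contains an <|AUDIO|> placeholder, but the audio embeddings are simply concatenated with the prompt rather than spliced in at the placeholder's position. A fuller version would replace the placeholder embedding with the audio frames. The helper below is a hypothetical sketch that uses plain lists to stand in for embedding rows; with real tensors the same logic is a torch.cat along the sequence dimension.

```python
def splice_audio_embeds(prompt_embeds, audio_embeds, audio_pos):
    """Replace the <|AUDIO|> placeholder at index audio_pos with the
    audio embeddings (the placeholder row itself is dropped)."""
    return prompt_embeds[:audio_pos] + audio_embeds + prompt_embeds[audio_pos + 1:]

prompt = [f"tok{i}" for i in range(6)]     # 6 prompt-token embeddings
audio = [f"frame{i}" for i in range(40)]   # 40 projected audio frames
spliced = splice_audio_embeds(prompt, audio, audio_pos=4)
print(len(spliced))  # 45 = 6 - 1 + 40
```

Splicing at the placeholder keeps the tokens after <|AUDIO|> (e.g. the tail of an instruction) behind the audio, which matters for prompts where the audio is referenced mid-sentence.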
🧩 2. Client example (multiple users connecting concurrently)
import asyncio
import websockets
import torchaudio
import json

async def stream_audio(file_path):
    uri = "ws://127.0.0.1:8080/stream"
    async with websockets.connect(uri) as ws:
        print("[Client] Connected")

        # 1️⃣ Start a new session
        await ws.send(json.dumps({"type": "start", "prompt": "Transcribe this: <|AUDIO|>"}))
        resp = await ws.recv()
        start_info = json.loads(resp)
        sid = start_info["session_id"]
        print("[Client] Session:", sid)

        # 2️⃣ Send audio chunks (the server assumes 16 kHz mono PCM)
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform[0]
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
            sr = 16000
        chunk_size = sr * 5  # 5-second chunks
        for i in range(0, len(waveform), chunk_size):
            chunk = waveform[i:i + chunk_size].tolist()
            await ws.send(json.dumps({"type": "chunk", "session_id": sid, "pcm": chunk}))
            resp = await ws.recv()
            print("[Ack]", resp)

        # 3️⃣ Send the end message
        await ws.send(json.dumps({"type": "end", "session_id": sid}))
        result = await ws.recv()
        print("[Final]", result)

async def main():
    await asyncio.gather(
        stream_audio("sample1.wav"),
        stream_audio("sample2.wav"),
    )

asyncio.run(main())
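A quick sanity check on the framing above: at 16 kHz, each 5-second chunk carries 80,000 float samples, which is exactly the chunk_len value reported in the example logs later on.

```python
SAMPLE_RATE = 16000          # the server assumes 16 kHz PCM
CHUNK_SECONDS = 5
chunk_len = SAMPLE_RATE * CHUNK_SECONDS
print(chunk_len)             # 80000 samples per chunk

# A 30-second file would be sent as this many chunks (the last may be partial):
total_samples = 30 * SAMPLE_RATE
num_chunks = -(-total_samples // chunk_len)  # ceiling division
print(num_chunks)            # 6
```

Note that serializing each sample as JSON text inflates the payload considerably; switching to binary WebSocket frames would be an easy optimization if bandwidth matters.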
⚙️ 3. Running the demo
🧱 1️⃣ Start the server
uvicorn server_vllm_audio_streaming_session:app --host 0.0.0.0 --port 8080
🧩 2️⃣ Start multiple clients
Each client opens its own independent WebSocket connection.
FastAPI's async event loop plus the SessionCache let them run concurrently without interfering with one another.
🧠 4. How session management works
| Field | Meaning |
|---|---|
| session_id | Unique identifier for each connection (UUID) |
| embeds | Temporary cache of all audio embeddings for this session |
| prompt | Prompt for the current session |
| lock | asyncio lock that prevents concurrent writes within one session |

Concurrency-safety strategy:
- Different WebSocket connections → independent sessions (run concurrently)
- Audio chunks within one session are serialized by the lock (processed in order)
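To make the isolation concrete, here is a stripped-down, standalone mirror of the server's SessionCache showing that two sessions never share state: the UUIDs differ, and writes to one session's embeds are invisible to the other.

```python
import asyncio
import uuid

class SessionCache:
    """Standalone mirror of the server's per-connection cache."""
    def __init__(self):
        self.sessions = {}

    def create(self, prompt):
        sid = str(uuid.uuid4())
        self.sessions[sid] = {"embeds": [], "prompt": prompt, "lock": asyncio.Lock()}
        return sid

    def get(self, sid):
        return self.sessions.get(sid)

cache = SessionCache()
a = cache.create("Transcribe this: <|AUDIO|>")
b = cache.create("Translate this: <|AUDIO|>")
cache.get(a)["embeds"].append("chunk-a1")

print(a != b)                  # True: every connection gets its own UUID
print(cache.get(b)["embeds"])  # []: one session's writes never leak into another
```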
🔄 5. Extension points
Directions for further enhancement:
| Direction | Description |
|---|---|
| ✅ KV cache support | Replace language_model.generate() with incremental forward() calls that keep the context cache |
| ✅ Session timeout cleanup | Periodically scan for and evict idle sessions (prevents memory leaks) |
| ✅ Persistent storage | Spill embeds to disk or Redis to support very long audio |
| ✅ Multi-GPU parallelism | Swap transformers for vLLM's AsyncLLMEngine |
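The "session timeout cleanup" row can be sketched as a background asyncio task. The last_seen field here is an assumption of this sketch, not part of the server code above; a real handler would refresh it on every received chunk.

```python
import asyncio
import time

IDLE_TIMEOUT = 1.0  # hypothetical threshold; tune for real workloads

async def cleanup_loop(sessions, interval=0.2):
    """Periodically evict sessions that have been idle longer than IDLE_TIMEOUT."""
    while True:
        now = time.monotonic()
        for sid in [s for s, v in sessions.items() if now - v["last_seen"] > IDLE_TIMEOUT]:
            del sessions[sid]
            print(f"[Cleanup] Evicted idle session {sid}")
        await asyncio.sleep(interval)

async def main():
    sessions = {
        "active": {"last_seen": time.monotonic()},
        "stale": {"last_seen": time.monotonic() - 10},  # already idle for 10 s
    }
    task = asyncio.create_task(cleanup_loop(sessions))
    await asyncio.sleep(0.5)  # let the cleanup task run a few times
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return sessions

remaining = asyncio.run(main())
print(sorted(remaining))  # the stale session is gone, the active one survives
```

In the real server this task would be started once at application startup (e.g. in a FastAPI startup hook) against session_cache.sessions.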
✅ 6. Sample output (example logs)
[Init] Loading model from Qwen/Qwen2-Audio-7B-Instruct ...
[Session] Created: 31f8c9b7-12a2-4a15-9a17-343e3f9ad6c5
[Session] Created: 122fcb18-76b2-4f9a-9d7a-90a3289ac8cc
[WS] Client connected
[WS] Client connected
[Ack] {"status":"ok","session_id":"31f8...","chunk_len":80000}
[Ack] {"status":"ok","session_id":"122f...","chunk_len":80000}
[Final] {"status":"done","text":"this is qwen audio transcription"}