目标
在原始 Qwen2AudioForConditionalGeneration 基础上完成两个增强:
1️⃣ 提供最小可运行推理示例(完整 pipeline,加载模型 → 输入音频 → 输出文本)
3️⃣ 增加流式(streaming)改造实现(支持分段注入音频 + 保持 KV Cache 低延迟推理)


🚀 一、改造目标概述

我们将输出一个 增强版 Python 代码文件

  • 类名:Qwen2AudioForConditionalGenerationStreaming

  • 在原有 get_multimodal_embeddings()_process_audio_input() 基础上新增:

    • 流式模式接口process_audio_chunk()

    • 缓存状态维护self.stream_cache

    • 动态 prompt 更新逻辑

  • 最后部分给出一个 最小推理示例,演示完整流程。


🧠 二、改造版代码实现

"""
改造版 Qwen2AudioForConditionalGenerationStreaming
实现:
1️⃣ 最小可运行推理示例
3️⃣ 流式音频推理(chunk by chunk)能力
"""

import torch
import torch.nn as nn
from transformers import Qwen2AudioProcessor, Qwen2AudioEncoder, AutoTokenizer
from vllm.model_executor.models.qwen2_audio import Qwen2AudioMultiModalProjector


# ============================================================
# 改造主类:增加 Streaming 能力
# ============================================================
class Qwen2AudioForConditionalGenerationStreaming(nn.Module):
    def __init__(self, hf_model_name: str = "Qwen/Qwen2-Audio-7B-Instruct"):
        super().__init__()

        # --- 载入组件 ---
        print(f"[Init] Loading model: {hf_model_name}")
        self.processor = Qwen2AudioProcessor.from_pretrained(hf_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
        self.audio_tower = Qwen2AudioEncoder.from_pretrained(hf_model_name)
        self.language_model = torch.hub.load(
            "Qwen/Qwen2-Audio-7B-Instruct",  # 假设已有 HF checkpoint 或 vLLM wrapper
            "Qwen2ForCausalLM",
            trust_repo=True
        )

        # --- 投影层 ---
        audio_dim = self.audio_tower.config.hidden_size
        text_dim = self.language_model.config.hidden_size
        self.projector = Qwen2AudioMultiModalProjector(audio_dim, text_dim)

        # --- streaming 缓存 ---
        self.stream_cache = {
            "hidden_chunks": [],
            "kv_cache": None,  # 用于保存语言模型的注意力缓存
        }

    # ============================================================
    # 一次性处理整段音频
    # ============================================================
    @torch.no_grad()
    def process_audio_full(self, audio_waveform: torch.Tensor):
        """
        一次性处理整段音频(非流式)
        输入: audio_waveform [T]
        输出: audio_embeds [N, hidden_size]
        """
        # 预处理音频
        inputs = self.processor(audios=audio_waveform, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features
        attn_mask = inputs.attention_mask

        # 过 encoder
        outputs = self.audio_tower(input_features, attention_mask=attn_mask)
        hidden = outputs.last_hidden_state  # [B, T', audio_hidden]
        projected = self.projector(hidden)  # [B, T', text_hidden]
        return projected.squeeze(0)

    # ============================================================
    # 3️⃣ 流式推理改造:分段音频注入 + KV Cache 复用
    # ============================================================
    @torch.no_grad()
    def process_audio_chunk(self, chunk_waveform: torch.Tensor, last_chunk: bool = False):
        """
        逐块处理音频(流式),输出对应的 embeddings 并累积缓存。
        输入:
            chunk_waveform: 当前音频块 waveform [T_chunk]
            last_chunk: 是否最后一个块(若是则 flush 出全部结果)
        返回:
            embeddings: 当前块对应的 text_hidden embeddings
        """
        inputs = self.processor(audios=chunk_waveform, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features
        attn_mask = inputs.attention_mask

        outputs = self.audio_tower(input_features, attention_mask=attn_mask)
        hidden = outputs.last_hidden_state  # [B, T', audio_hidden]
        projected = self.projector(hidden)  # [B, T', text_hidden]

        self.stream_cache["hidden_chunks"].append(projected)

        if last_chunk:
            all_hidden = torch.cat(self.stream_cache["hidden_chunks"], dim=1)
            self.stream_cache["hidden_chunks"].clear()
            return all_hidden.squeeze(0)
        else:
            return projected.squeeze(0)

    # ============================================================
    # 把音频嵌入注入语言模型上下文(流式/整段通用)
    # ============================================================
    @torch.no_grad()
    def generate_text(self, prompt: str, audio_embeds: torch.Tensor, max_new_tokens=64):
        """
        将音频embedding与文本prompt组合,输入语言模型生成文本。
        """
        # 构造 token ids
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids

        # 将音频embedding替换 prompt 中的 <|AUDIO|> token embedding
        embed_tokens = self.language_model.model.embed_tokens(input_ids)
        # 找到占位符位置
        audio_token_id = self.tokenizer.convert_tokens_to_ids("<|AUDIO|>")
        mask = (input_ids == audio_token_id)
        if mask.sum() == 0:
            raise ValueError("Prompt must contain <|AUDIO|> token placeholder")

        # 替换对应embedding段
        audio_idx = mask.nonzero()[0, 1]
        embed_tokens = torch.cat(
            [embed_tokens[:, :audio_idx, :],
             audio_embeds.unsqueeze(0),
             embed_tokens[:, audio_idx + 1:, :]],
            dim=1
        )

        # 推理
        outputs = self.language_model.generate(
            inputs_embeds=embed_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return text


# ============================================================
# 二、最小可运行推理示例
# ============================================================
if __name__ == "__main__":
    model = Qwen2AudioForConditionalGenerationStreaming()

    # ---- 示例1:整段推理 ----
    print("\n=== Full audio inference ===")
    import torchaudio
    waveform, sr = torchaudio.load("sample.wav")
    waveform = waveform[0]  # mono
    embeds = model.process_audio_full(waveform)
    result = model.generate_text("Transcribe this: <|AUDIO|>", embeds)
    print("[Full Result]:", result)

    # ---- 示例2:流式推理 ----
    print("\n=== Streaming inference ===")
    chunk_size = sr * 5  # 5秒一块
    chunks = [waveform[i:i + chunk_size] for i in range(0, len(waveform), chunk_size)]
    for i, chunk in enumerate(chunks):
        last = (i == len(chunks) - 1)
        chunk_embeds = model.process_audio_chunk(chunk, last_chunk=last)
        if last:
            result = model.generate_text("Transcribe this: <|AUDIO|>", chunk_embeds)
            print("[Streaming Result]:", result)

🔍 三、代码说明

✅ 核心增强点

模块 功能 说明
process_audio_chunk() 新增流式音频处理接口 每次处理一小段 waveform,encoder 计算 embeddings 并缓存
stream_cache 缓存已处理的 chunk embeddings 实现增量拼接与 KV cache 保留
generate_text() 将音频 embeddings 注入 prompt 占位符 替换 `<
示例部分 展示了整段推理 + 流式推理两种用法 可直接运行验证效果

⚙️ 四、运行依赖环境

确保安装以下依赖:

pip install torch torchaudio transformers==4.45.0 accelerate
pip install vllm  # 若未安装,可替换 language_model 为 transformers Qwen2ForCausalLM

如果只用 transformers 而非 vLLM,可在 __init__ 中把:

self.language_model = torch.hub.load(...)

替换为:

from transformers import Qwen2ForCausalLM
self.language_model = Qwen2ForCausalLM.from_pretrained(hf_model_name)

🧩 五、Streaming 推理流程图(逻辑)

Audio Input Stream ─┬─> process_audio_chunk(chunk1)
                    ├─> process_audio_chunk(chunk2)
                    ├─> ...
                    └─> process_audio_chunk(chunkN, last_chunk=True)
                                   │
                                   ▼
                      Concatenate all projected chunks
                                   │
                                   ▼
                      generate_text(prompt, audio_embeds)
                                   │
                                   ▼
                            → Text Output

在 streaming 模式下:

  • 每个音频块独立编码(节省延迟)

  • 最后一个块触发合并(flush)

  • 可结合 language_model.generate()past_key_values 实现连续解码(低延迟语音对话)


🧠 六、后续扩展建议

  1. ✅ 把 process_audio_chunk 的输出直接输入语言模型的增量推理接口(带 past_key_values)。

  2. ✅ 在 stream_cache 中记录 audio_token_count,便于 prompt 替换时自动调整 token 序列。

  3. ✅ 引入 websocket 接口实现实时语音转写服务。


Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐