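Voicebox is a local-first voice cloning studio with DAW-like features for professional speech synthesis. Think of it as a local, free, open-source alternative to ElevenLabs: you download models, clone voices, and generate speech on your own machine.
I downloaded the project source and installed and configured the bun, Rust, and Python environments as the project requires. When I ran the project with bun run dev:server and bun run tauri dev in two separate consoles, the server still contacted HuggingFace to download model files even though the models were already fully downloaded to the local cache. It took most of an evening in Trae IDE, working through several stages with the help of Trae's backend AI, to solve the offline-loading problem for the qwen-tts models. What follows is the full troubleshooting process as summarized by Trae, shared here for anyone who runs into the same issue.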

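I. Solving offline loading for the Qwen-tts models

Background

Initial requirements

At runtime the models were already in the local cache, yet the project could not run without reaching the HuggingFace website. The requirements were:

  1. Disable certificate verification when downloading models
  2. Use a HuggingFace mirror in China when downloading models
  3. Once a model has been downloaded locally, make no further network requests for its files

Environment

  • Operating system: Windows
  • Model cache path: C:\Users\{username}\.cache\huggingface\hub\
  • Models involved: Qwen3-TTS-12Hz-1.7B-Base, Qwen3-TTS-12Hz-0.6B-Base, Whisper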

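Problem analysis

Phase 1: patching (bottom-up)

1.1 First attempt

Create an hf_config.py module and try to control HuggingFace Hub behavior through environment variables:

import os
import ssl

# Enable offline mode
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# Use the China mirror
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Disable SSL verification (relies on a private stdlib hook)
ssl._create_default_https_context = ssl._create_unverified_context

Result: partially effective, but the qwen_tts library still attempted network access internally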
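
One detail worth checking with this approach is ordering: as far as I know, huggingface_hub reads HF_HUB_OFFLINE and HF_ENDPOINT when it is first imported, so the variables have to be set before that import happens. Below is a minimal sketch of that idea. The helper name configure_hf_env is hypothetical and is not part of the project.

```python
import os
from typing import Dict, Optional


def configure_hf_env(offline: bool, mirror: Optional[str] = None) -> Dict[str, str]:
    """Set HuggingFace-related environment variables and return what was set.

    Hypothetical helper: call it before importing huggingface_hub or
    transformers, since those libraries read the variables at import time.
    """
    values: Dict[str, str] = {}
    if offline:
        values["HF_HUB_OFFLINE"] = "1"
        values["TRANSFORMERS_OFFLINE"] = "1"
    if mirror:
        values["HF_ENDPOINT"] = mirror
    os.environ.update(values)
    return values
```

Calling configure_hf_env(offline=False, mirror="https://hf-mirror.com") at the very top of the entry module would cover the download case, and offline=True the cached case.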

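1.2 Problems encountered

Problem 1: FutureWarning

FutureWarning: You are using `torch.load` with `weights_only=False`

Fix: pass weights_only=False explicitly and suppress the warning

Problem 2: SSL warning

InsecureRequestWarning: Unverified HTTPS request

Fix: suppress the SSL warning

Problem 3: offline-mode error

Error: Cannot reach `https://hf-mirror.com/api/models/...`
offline mode is enabled.

Cause: the qwen_tts library still tries to validate against the network even when offline mode is set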
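
The two warning fixes above can be sketched with the standard warnings module. The function names match the ones imported in backend/main.py later in this post, but the bodies here are my assumption of what such helpers might look like, not the project's actual code.

```python
import warnings


def suppress_common_warnings() -> None:
    """Ignore the torch.load FutureWarning quoted above (matched by message prefix)."""
    warnings.filterwarnings(
        "ignore",
        message=r"You are using `torch\.load` with `weights_only=False`",
        category=FutureWarning,
    )


def suppress_ssl_warnings() -> None:
    """Silence urllib3's InsecureRequestWarning if urllib3 is installed."""
    try:
        import urllib3
    except ImportError:
        return  # urllib3 is absent, so there is nothing to silence
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
```

Filtering by message prefix keeps other FutureWarning messages visible, which is usually preferable to ignoring the whole category.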

1.3 The problem in the qwen_tts library

Reading the qwen_tts source shows:

# qwen_tts/inference/qwen3_tts_model.py
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
# Problem: the processor is loaded without **kwargs

Attempted fix: create a qwen_tts_patch.py patch so that local_files_only=True is also passed to the processor
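
The contents of qwen_tts_patch.py are not shown in this write-up. As an illustration of the general technique only, here is a sketch that wraps any class's from_pretrained so that local_files_only defaults to True. The name force_local_files_only is hypothetical, and the wrapper binds to the class it is applied to, so subclasses would need to be patched separately.

```python
import functools


def force_local_files_only(cls):
    """Wrap cls.from_pretrained so local_files_only defaults to True.

    Hypothetical helper for illustration; an explicit local_files_only
    argument from the caller still wins.
    """
    original = cls.from_pretrained  # classmethod already bound to cls

    @functools.wraps(original)
    def patched(*args, **kwargs):
        kwargs.setdefault("local_files_only", True)
        return original(*args, **kwargs)

    cls.from_pretrained = staticmethod(patched)
    return cls
```

Applied to the processor class before qwen_tts loads it, a wrapper like this would make up for the missing **kwargs, at the cost of changing behavior for every other caller in the process.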

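1.4 The .no_exist directory

The .no_exist directory records which files do not exist in the remote repository:

models--Qwen--Qwen3-TTS-12Hz-1.7B-Base/
└── .no_exist/
    └── fd4b254389122332181a7c3db7f27e918eec64e3/
        ├── processor_config.json  (0 bytes)
        ├── tokenizer.json         (0 bytes)
        └── ...

Attempted fix: add a clean_no_exist_cache() function that clears this directory

Result: network access still could not be fully prevented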
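
The implementation of clean_no_exist_cache() is not included here. A minimal sketch of what it could look like follows. The signature (taking the repository cache directory as a Path and returning a bool) is my assumption.

```python
import shutil
from pathlib import Path


def clean_no_exist_cache(repo_cache: Path) -> bool:
    """Delete <repo_cache>/.no_exist and return True if something was removed."""
    no_exist = repo_cache / ".no_exist"
    if not no_exist.is_dir():
        return False
    shutil.rmtree(no_exist)
    return True
```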


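Phase 2: from first principles to the problem (top-down)

2.1 Analyzing the HuggingFace cache mechanism

Inspect the cache directory structure from the command line:

# Inspect the blobs directory
Get-ChildItem "...\models--Qwen--Qwen3-TTS-12Hz-1.7B-Base\blobs"
# Large file found: 38fc7fc51c5e... (3678.72 MB), the actual model weights

# Inspect the snapshots directory
Get-ChildItem "...\snapshots\fd4b254389122332181a7c3db7f27e918eec64e3"
# Symbolic links found:
# config.json → ../../blobs/xxx
# model.safetensors → ../../blobs/xxx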
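
The same inspection can be done from Python, which is handy on machines without PowerShell. This is a small sketch; list_blobs is a hypothetical name, and the repository cache directory is passed in rather than guessed.

```python
from pathlib import Path
from typing import List, Tuple


def list_blobs(repo_cache: Path) -> List[Tuple[str, float]]:
    """Return (file name, size in MB) for every blob, largest first."""
    blobs = repo_cache / "blobs"
    if not blobs.is_dir():
        return []
    sizes = [
        (p.name, p.stat().st_size / (1024 * 1024))
        for p in blobs.iterdir()
        if p.is_file()
    ]
    return sorted(sizes, key=lambda item: item[1], reverse=True)
```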

2.2 Understanding the cache structure

HuggingFace cache structure:

models--Qwen--Qwen3-TTS-12Hz-1.7B-Base/
├── blobs/                          # Actual file contents (named by content hash, SHA-256 for large files)
│   ├── 38fc7fc51c5e...            # 3.6 GB model weights
│   ├── 836b7b357f5e...            # 650 MB other file
│   └── ...
├── snapshots/                      # Symlink directories (named by commit hash)
│   └── fd4b254389122332...        # Snapshot directory
│       ├── config.json → ../../blobs/xxx      # Symbolic link
│       ├── model.safetensors → ../../blobs/xxx
│       ├── preprocessor_config.json → ../../blobs/xxx
│       └── ...
├── refs/                           # References (main, etc.)
└── .no_exist/                      # Records files that do not exist remotely
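
Given this layout, a snapshot directory can be located without any network call by reading the ref file, which holds the commit hash as plain text. The sketch below assumes that layout; resolve_snapshot is a hypothetical helper, not part of huggingface_hub or the project.

```python
from pathlib import Path
from typing import Optional


def resolve_snapshot(repo_cache: Path, revision: str = "main") -> Optional[Path]:
    """Return snapshots/<commit> for the given ref, or None if it is not cached."""
    ref_file = repo_cache / "refs" / revision
    if not ref_file.is_file():
        return None
    # The ref file contains the commit hash the branch pointed to at download time
    commit = ref_file.read_text(encoding="utf-8").strip()
    snapshot = repo_cache / "snapshots" / commit
    return snapshot if snapshot.is_dir() else None
```

Resolving through refs/main picks the snapshot the branch actually points to, which may be more predictable than choosing the most recently modified snapshot directory.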
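
2.3 Key insight

Findings

  • When loading by HuggingFace Hub ID (e.g. Qwen/Qwen3-TTS-12Hz-1.7B-Base), from_pretrained():

    1. Resolves the Hub ID
    2. Checks the local cache
    3. Tries to verify the state of the remote repository (observed here even with local_files_only=True set, because qwen_tts did not forward it to the processor)
    4. May trigger network requests
  • When loading by local path (e.g. C:\Users\xxx\.cache\...\snapshots\xxx), it:

    1. Reads the local files directly
    2. Bypasses HuggingFace Hub's network verification entirely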

2.4 Final solution

Change the model-loading logic:

def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
    """
    Return the local snapshot path if the model is fully cached, else None.
    """
    try:
        from huggingface_hub import constants as hf_constants
        model_id = self._get_model_path(model_size)
        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))

        if not repo_cache.exists():
            return None

        # An .incomplete blob means a download is still running or was interrupted
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            return None

        # Pick the most recently modified snapshot
        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return None

        snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
        if not snapshot_dirs:
            return None

        latest_snapshot = snapshot_dirs[0]

        # Require at least one weights file
        has_weights = (
            any(latest_snapshot.rglob("*.safetensors")) or
            any(latest_snapshot.rglob("*.bin"))
        )
        if not has_weights:
            return None

        # Require config.json
        if not (latest_snapshot / "config.json").exists():
            return None

        return latest_snapshot
    except Exception as e:
        print(f"[_get_local_snapshot_path] Error: {e}")
        return None

def _load_model_sync(self, model_size: str):
    """Load the model synchronously."""
    # Look up the local snapshot path (None if the model is not fully cached)
    local_snapshot_path = self._get_local_snapshot_path(model_size)
    is_cached = local_snapshot_path is not None

    # HuggingFace Hub ID (used for downloading)
    model_id = self._get_model_path(model_size)

    # Choose the load target:
    # if cached, use the local snapshot path to avoid any network access;
    # if not cached, use the HuggingFace Hub ID so the model is downloaded
    load_path = str(local_snapshot_path) if is_cached else model_id

    if is_cached:
        print(f"Loading model {model_size} from local cache: {load_path}")
    else:
        print(f"Model {model_size} not cached, will download from HuggingFace Hub")
        setup_huggingface_for_online()

    # Load the model
    self.model = Qwen3TTSModel.from_pretrained(load_path, ...)

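Comparison of approaches

| Approach | Pros | Cons | Result |
|---|---|---|---|
| Set HF_HUB_OFFLINE=1 | Simple | Some libraries still attempt network verification | ❌ Failed |
| Pass local_files_only=True | Standard practice | The qwen_tts library does not pass it to the processor | ❌ Failed |
| Patch qwen_tts | Targeted | High maintenance cost, may affect other libraries | ⚠️ Partially effective |
| Clear the .no_exist directory | Reduces network requests | Cannot fully prevent network access | ⚠️ Partially effective |
| Load directly from the local snapshot path | Fully offline, zero network access | Requires checking cache integrity | ✅ Success |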

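Final implementation

Modified files

  1. backend/backends/pytorch_backend.py

    • Added the _get_local_snapshot_path() method
    • Modified the _load_model_sync() method (TTS and STT)
  2. backend/backends/mlx_backend.py

    • Applied the same changes to the MLX backend
  3. backend/utils/hf_config.py

    • Kept the China mirror and SSL configuration
    • Removed the offline-mode settings that are no longer needed

Deleted files

  • backend/utils/qwen_tts_patch.py (the patch is no longer needed)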

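Methodology summary

Bottom-up vs. top-down

| Approach | When it fits | How it showed up in this session |
|---|---|---|
| Bottom-up (fix the problem at hand) | Quick fixes, temporary workarounds, clearly bounded problems | Setting environment variables, patching, clearing the cache |
| Top-down (from principles to the problem) | Complex systems, root-cause problems, cases that need deep understanding | Analyzing the HuggingFace cache mechanism and what the load path really means |

Takeaways from this session

  1. Understanding how the system works is the key to solving complex problems

    • HuggingFace's cache mechanism (blobs + snapshots + symbolic links)
    • How from_pretrained() behaves differently for a Hub ID vs. a local path
  2. Limits of the patching approach

    • Only addresses surface symptoms
    • Can introduce new complexity
    • Hard to maintain
  3. Advantages of starting from first principles

    • Finds the root-cause solution
    • Code is simpler and more reliable
    • Easier to maintain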

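Best practices

Offline model loading checklist

  1. Check cache integrity

    • The blobs directory exists and contains the large files
    • The snapshots directory exists and contains symbolic links
    • There are no .incomplete files
    • config.json exists
  2. Load from a local path

    # Recommended: use the local snapshot path
    load_path = str(local_snapshot_path)  # e.g. C:\Users\xxx\.cache\...\snapshots\xxx
    
    # Not recommended: use the Hub ID (may trigger network access)
    load_path = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    
  3. Environment configuration (optional)

    # China mirror (for downloads)
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    
    # Disable SSL verification (if needed)
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context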
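
The integrity items in step 1 can be turned into a small diagnostic that reports what is missing instead of only returning None. This is a simplified sketch: check_cached_repo is a hypothetical name, and it checks only for the presence of files, not their size or whether they are symbolic links.

```python
from pathlib import Path
from typing import List


def check_cached_repo(repo_cache: Path) -> List[str]:
    """Return a list of problems; an empty list means every check passed."""
    problems: List[str] = []
    blobs = repo_cache / "blobs"
    snapshots = repo_cache / "snapshots"

    if not blobs.is_dir() or not any(blobs.iterdir()):
        problems.append("blobs directory is missing or empty")
    elif any(blobs.glob("*.incomplete")):
        problems.append("found .incomplete files in blobs")

    snapshot_dirs = (
        [p for p in snapshots.iterdir() if p.is_dir()] if snapshots.is_dir() else []
    )
    if not snapshot_dirs:
        problems.append("snapshots directory is missing or empty")
    elif not any((s / "config.json").exists() for s in snapshot_dirs):
        problems.append("no snapshot contains config.json")
    return problems
```

Logging the returned list before falling back to the Hub ID makes it easier to see why a model that looks downloaded is still treated as uncached.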
    


Document created: 2026-03-10
Status: resolved ✅

II. Final code files resulting from the changes above

1. backend/main.py

"""
FastAPI application for voicebox backend.

Handles voice cloning, generation history, and server mode.
"""

# Suppress warnings before anything else is imported
from .utils.warning_suppressor import suppress_common_warnings, suppress_ssl_warnings
suppress_common_warnings()
suppress_ssl_warnings()

from fastapi import FastAPI, Depends, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import asyncio
import uvicorn
import argparse
import torch
import tempfile
import io
from pathlib import Path
import uuid
import signal
import os
from urllib.parse import quote


def _safe_content_disposition(disposition_type: str, filename: str) -> str:
    """Build a Content-Disposition header that is safe for non-ASCII filenames.

    Uses RFC 5987 ``filename*`` parameter so that browsers can decode
    UTF-8 filenames while the ``filename`` fallback stays ASCII-only.
    """
    ascii_name = "".join(
        c for c in filename if c.isascii() and (c.isalnum() or c in " -_.")
    ).strip() or "download"
    utf8_name = quote(filename, safe="")
    return (
        f'{disposition_type}; filename="{ascii_name}"; '
        f"filename*=UTF-8''{utf8_name}"
    )


from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, __version__
from .database import get_db, Generation as DBGeneration, VoiceProfile as DBVoiceProfile
from .utils.progress import get_progress_manager
from .utils.tasks import get_task_manager
from .utils.cache import clear_voice_prompt_cache
from .platform_detect import get_backend_type

app = FastAPI(
    title="voicebox API",
    description="Production-quality Qwen3-TTS voice cloning API",
    version=__version__,
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============================================
# ROOT & HEALTH ENDPOINTS
# ============================================

@app.get("/")
async def root():
    """Root endpoint."""
    return {"message": "voicebox API", "version": __version__}


@app.post("/shutdown")
async def shutdown():
    """Gracefully shutdown the server."""
    async def shutdown_async():
        await asyncio.sleep(0.1)  # Give response time to send
        os.kill(os.getpid(), signal.SIGTERM)

    asyncio.create_task(shutdown_async())
    return {"message": "Shutting down..."}


@app.get("/health", response_model=models.HealthResponse)
async def health():
    """Health check endpoint."""
    # Path and os are already imported at module level; hf_hub_download is not used here
    from huggingface_hub import constants as hf_constants

    tts_model = tts.get_tts_model()
    backend_type = get_backend_type()

    # Check for GPU availability (CUDA, MPS, Intel Arc XPU, or DirectML)
    has_cuda = torch.cuda.is_available()
    has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

    # Intel Arc / Intel Xe via intel-extension-for-pytorch (IPEX)
    has_xpu = False
    xpu_name = None
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            has_xpu = True
            try:
                xpu_name = torch.xpu.get_device_name(0)
            except Exception:
                xpu_name = "Intel GPU"
    except ImportError:
        pass

    # DirectML backend (torch-directml) for any Windows GPU
    has_directml = False
    directml_name = None
    try:
        import torch_directml
        if torch_directml.device_count() > 0:
            has_directml = True
            try:
                directml_name = torch_directml.device_name(0)
            except Exception:
                directml_name = "DirectML GPU"
    except ImportError:
        pass

    gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"

    gpu_type = None
    if has_cuda:
        gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
    elif has_mps:
        gpu_type = "MPS (Apple Silicon)"
    elif backend_type == "mlx":
        gpu_type = "Metal (Apple Silicon via MLX)"
    elif has_xpu:
        gpu_type = f"XPU ({xpu_name})"
    elif has_directml:
        gpu_type = f"DirectML ({directml_name})"

    vram_used = None
    if has_cuda:
        vram_used = torch.cuda.memory_allocated() / 1024 / 1024  # MB
    
    # Check if model is loaded - use the same logic as model status endpoint
    model_loaded = False
    model_size = None
    try:
        # Use the same check as model status endpoint
        if tts_model.is_loaded():
            model_loaded = True
            # Get the actual loaded model size
            # Check _current_model_size first (more reliable for actually loaded models)
            model_size = getattr(tts_model, '_current_model_size', None)
            if not model_size:
                # Fallback to model_size attribute (which should be set when model loads)
                model_size = getattr(tts_model, 'model_size', None)
    except Exception:
        # If there's an error checking, assume not loaded
        model_loaded = False
        model_size = None
    
    # Check if default model is downloaded (cached)
    model_downloaded = None
    try:
        # Check if the default model (1.7B) is cached
        # Use different model IDs based on backend
        if backend_type == "mlx":
            default_model_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
        else:
            default_model_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
        
        # Method 1: Try scan_cache_dir if available
        try:
            from huggingface_hub import scan_cache_dir
            cache_info = scan_cache_dir()
            for repo in cache_info.repos:
                if repo.repo_id == default_model_id:
                    model_downloaded = True
                    break
        except Exception:  # also covers ImportError
            # Method 2: Check cache directory (using HuggingFace's OS-specific cache location)
            cache_dir = hf_constants.HF_HUB_CACHE
            repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
            if repo_cache.exists():
                has_model_files = (
                    any(repo_cache.rglob("*.bin")) or
                    any(repo_cache.rglob("*.safetensors")) or
                    any(repo_cache.rglob("*.pt")) or
                    any(repo_cache.rglob("*.pth")) or
                    any(repo_cache.rglob("*.npz"))  # MLX models may use npz
                )
                model_downloaded = has_model_files
    except Exception:
        pass
    
    return models.HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        model_downloaded=model_downloaded,
        model_size=model_size,
        gpu_available=gpu_available,
        gpu_type=gpu_type,
        vram_used_mb=vram_used,
        backend_type=backend_type,
    )


# ============================================
# VOICE PROFILE ENDPOINTS
# ============================================

@app.post("/profiles", response_model=models.VoiceProfileResponse)
async def create_profile(
    data: models.VoiceProfileCreate,
    db: Session = Depends(get_db),
):
    """Create a new voice profile."""
    try:
        return await profiles.create_profile(data, db)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/profiles", response_model=List[models.VoiceProfileResponse])
async def list_profiles(db: Session = Depends(get_db)):
    """List all voice profiles."""
    return await profiles.list_profiles(db)


@app.post("/profiles/import", response_model=models.VoiceProfileResponse)
async def import_profile(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Import a voice profile from a ZIP archive."""
    # Validate file size (max 100MB)
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
    
    # Read file content
    content = await file.read()
    
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
        )
    
    try:
        profile = await export_import.import_profile_from_zip(content, db)
        return profile
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def get_profile(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Get a voice profile by ID."""
    profile = await profiles.get_profile(profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")
    return profile


@app.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def update_profile(
    profile_id: str,
    data: models.VoiceProfileCreate,
    db: Session = Depends(get_db),
):
    """Update a voice profile."""
    profile = await profiles.update_profile(profile_id, data, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")
    return profile


@app.delete("/profiles/{profile_id}")
async def delete_profile(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Delete a voice profile."""
    success = await profiles.delete_profile(profile_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Profile not found")
    return {"message": "Profile deleted successfully"}


@app.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
async def add_profile_sample(
    profile_id: str,
    file: UploadFile = File(...),
    reference_text: str = Form(...),
    db: Session = Depends(get_db),
):
    """Add a sample to a voice profile."""
    # Preserve the uploaded file's extension so librosa can detect format correctly.
    # Defaulting to .wav was causing soundfile to reject MP3/WebM content as invalid WAV.
    _allowed_audio_exts = {'.wav', '.mp3', '.m4a', '.ogg', '.flac', '.aac', '.webm', '.opus'}
    _uploaded_ext = Path(file.filename or '').suffix.lower()
    file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else '.wav'

    with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        sample = await profiles.add_profile_sample(
            profile_id,
            tmp_path,
            reference_text,
            db,
        )
        return sample
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)


@app.get("/profiles/{profile_id}/samples", response_model=List[models.ProfileSampleResponse])
async def get_profile_samples(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Get all samples for a profile."""
    return await profiles.get_profile_samples(profile_id, db)


@app.delete("/profiles/samples/{sample_id}")
async def delete_profile_sample(
    sample_id: str,
    db: Session = Depends(get_db),
):
    """Delete a profile sample."""
    success = await profiles.delete_profile_sample(sample_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Sample not found")
    return {"message": "Sample deleted successfully"}


@app.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
async def update_profile_sample(
    sample_id: str,
    data: models.ProfileSampleUpdate,
    db: Session = Depends(get_db),
):
    """Update a profile sample's reference text."""
    sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
    if not sample:
        raise HTTPException(status_code=404, detail="Sample not found")
    return sample


@app.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
async def upload_profile_avatar(
    profile_id: str,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Upload or update avatar image for a profile."""
    # Save uploaded file to temp location
    # file.filename can be None, so fall back to an empty name before taking the suffix
    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        profile = await profiles.upload_avatar(profile_id, tmp_path, db)
        return profile
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)


@app.get("/profiles/{profile_id}/avatar")
async def get_profile_avatar(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Get avatar image for a profile."""
    profile = await profiles.get_profile(profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    if not profile.avatar_path:
        raise HTTPException(status_code=404, detail="No avatar found for this profile")

    avatar_path = Path(profile.avatar_path)
    if not avatar_path.exists():
        raise HTTPException(status_code=404, detail="Avatar file not found")

    return FileResponse(avatar_path)


@app.delete("/profiles/{profile_id}/avatar")
async def delete_profile_avatar(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Delete avatar image for a profile."""
    success = await profiles.delete_avatar(profile_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
    return {"message": "Avatar deleted successfully"}


@app.get("/profiles/{profile_id}/export")
async def export_profile(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Export a voice profile as a ZIP archive."""
    try:
        # Get profile to get name for filename
        profile = await profiles.get_profile(profile_id, db)
        if not profile:
            raise HTTPException(status_code=404, detail="Profile not found")
        
        # Export to ZIP
        zip_bytes = export_import.export_profile_to_zip(profile_id, db)
        
        # Create safe filename
        safe_name = "".join(c for c in profile.name if c.isalnum() or c in (' ', '-', '_')).strip()
        if not safe_name:
            safe_name = "profile"
        filename = f"profile-{safe_name}.voicebox.zip"
        
        # Return as streaming response
        return StreamingResponse(
            io.BytesIO(zip_bytes),
            media_type="application/zip",
            headers={
                "Content-Disposition": _safe_content_disposition("attachment", filename)
            }
        )
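    except HTTPException:
        # Let deliberate HTTP errors (such as the 404 above) pass through unchanged
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))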


# ============================================
# AUDIO CHANNEL ENDPOINTS
# ============================================

@app.get("/channels", response_model=List[models.AudioChannelResponse])
async def list_channels(db: Session = Depends(get_db)):
    """List all audio channels."""
    return await channels.list_channels(db)


@app.post("/channels", response_model=models.AudioChannelResponse)
async def create_channel(
    data: models.AudioChannelCreate,
    db: Session = Depends(get_db),
):
    """Create a new audio channel."""
    try:
        return await channels.create_channel(data, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def get_channel(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Get an audio channel by ID."""
    channel = await channels.get_channel(channel_id, db)
    if not channel:
        raise HTTPException(status_code=404, detail="Channel not found")
    return channel


@app.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def update_channel(
    channel_id: str,
    data: models.AudioChannelUpdate,
    db: Session = Depends(get_db),
):
    """Update an audio channel."""
    try:
        channel = await channels.update_channel(channel_id, data, db)
        if not channel:
            raise HTTPException(status_code=404, detail="Channel not found")
        return channel
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.delete("/channels/{channel_id}")
async def delete_channel(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Delete an audio channel."""
    try:
        success = await channels.delete_channel(channel_id, db)
        if not success:
            raise HTTPException(status_code=404, detail="Channel not found")
        return {"message": "Channel deleted successfully"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/channels/{channel_id}/voices")
async def get_channel_voices(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Get list of profile IDs assigned to a channel."""
    try:
        profile_ids = await channels.get_channel_voices(channel_id, db)
        return {"profile_ids": profile_ids}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.put("/channels/{channel_id}/voices")
async def set_channel_voices(
    channel_id: str,
    data: models.ChannelVoiceAssignment,
    db: Session = Depends(get_db),
):
    """Set which voices are assigned to a channel."""
    try:
        await channels.set_channel_voices(channel_id, data, db)
        return {"message": "Channel voices updated successfully"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/profiles/{profile_id}/channels")
async def get_profile_channels(
    profile_id: str,
    db: Session = Depends(get_db),
):
    """Get list of channel IDs assigned to a profile."""
    try:
        channel_ids = await channels.get_profile_channels(profile_id, db)
        return {"channel_ids": channel_ids}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.put("/profiles/{profile_id}/channels")
async def set_profile_channels(
    profile_id: str,
    data: models.ProfileChannelAssignment,
    db: Session = Depends(get_db),
):
    """Set which channels a profile is assigned to."""
    try:
        await channels.set_profile_channels(profile_id, data, db)
        return {"message": "Profile channels updated successfully"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


# ============================================
# GENERATION ENDPOINTS
# ============================================

@app.post("/generate", response_model=models.GenerationResponse)
async def generate_speech(
    data: models.GenerationRequest,
    db: Session = Depends(get_db),
):
    """Generate speech from text using a voice profile."""
    task_manager = get_task_manager()
    generation_id = str(uuid.uuid4())
    
    try:
        # Start tracking generation
        task_manager.start_generation(
            task_id=generation_id,
            profile_id=data.profile_id,
            text=data.text,
        )
        
        # Get profile
        profile = await profiles.get_profile(data.profile_id, db)
        if not profile:
            raise HTTPException(status_code=404, detail="Profile not found")
        
        # Generate audio

        # Resolve model size and load the correct model FIRST.
        # This must happen before create_voice_prompt_for_profile because that
        # function calls load_model_async(None), which falls back to self.model_size.
        # If the model is already loaded with the right size at that point, it
        # returns immediately and the voice prompt is created by the correct model.
        tts_model = tts.get_tts_model()
        model_size = data.model_size or "1.7B"

        # Check if model needs to be downloaded first
        model_path = tts_model._get_model_path(model_size)
        if not tts_model._is_model_cached(model_size):
            # Model is not fully cached — kick off a background download and tell
            # the client to retry once it's ready.
            model_name = f"qwen-tts-{model_size}"

            async def download_model_background():
                try:
                    await tts_model.load_model_async(model_size)
                except Exception as e:
                    task_manager.error_download(model_name, str(e))

            task_manager.start_download(model_name)
            asyncio.create_task(download_model_background())

            raise HTTPException(
                status_code=202,
                detail={
                    "message": f"Model {model_size} is being downloaded. Please wait and try again.",
                    "model_name": model_name,
                    "downloading": True,
                },
            )

        # Load (or switch to) the requested model before building the voice prompt
        await tts_model.load_model_async(model_size)

        # Create voice prompt from profile (model is already loaded with correct size)
        voice_prompt = await profiles.create_voice_prompt_for_profile(
            data.profile_id,
            db,
        )

        audio, sample_rate = await tts_model.generate(
            data.text,
            voice_prompt,
            data.language,
            data.seed,
            data.instruct,
        )

        # Calculate duration
        duration = len(audio) / sample_rate

        # Save audio
        audio_path = config.get_generations_dir() / f"{generation_id}.wav"

        from .utils.audio import save_audio
        save_audio(audio, str(audio_path), sample_rate)

        # Create history entry
        generation = await history.create_generation(
            profile_id=data.profile_id,
            text=data.text,
            language=data.language,
            audio_path=str(audio_path),
            duration=duration,
            seed=data.seed,
            db=db,
            instruct=data.instruct,
        )
        
        # Mark generation as complete
        task_manager.complete_generation(generation_id)
        
        return generation
        
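    except HTTPException:
        # Pass through deliberate HTTP responses (404, 202) instead of turning them into 500s
        task_manager.complete_generation(generation_id)
        raise
    except ValueError as e:
        task_manager.complete_generation(generation_id)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        task_manager.complete_generation(generation_id)
        raise HTTPException(status_code=500, detail=str(e))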


@app.post("/generate/stream")
async def stream_speech(
    data: models.GenerationRequest,
    db: Session = Depends(get_db),
):
    """
    Generate speech and stream the WAV audio directly without saving to disk.

    Returns raw WAV bytes via a StreamingResponse so the client can start
    playing audio before the entire file has been received.  This endpoint
    does NOT create a history entry — use /generate for that.
    """
    profile = await profiles.get_profile(data.profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    tts_model = tts.get_tts_model()
    model_size = data.model_size or "1.7B"

    if not tts_model._is_model_cached(model_size):
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
        )

    # Load the correct model before building the voice prompt (fixes issue #96)
    await tts_model.load_model_async(model_size)

    voice_prompt = await profiles.create_voice_prompt_for_profile(data.profile_id, db)

    audio, sample_rate = await tts_model.generate(
        data.text,
        voice_prompt,
        data.language,
        data.seed,
        data.instruct,
    )

    wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)

    async def _wav_stream():
        # Yield in chunks so large responses don't block the event loop
        chunk_size = 64 * 1024  # 64 KB
        for i in range(0, len(wav_bytes), chunk_size):
            yield wav_bytes[i : i + chunk_size]

    return StreamingResponse(
        _wav_stream(),
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
    )


# ============================================
# HISTORY ENDPOINTS
# ============================================

@app.get("/history", response_model=models.HistoryListResponse)
async def list_history(
    profile_id: Optional[str] = None,
    search: Optional[str] = None,
    limit: int = 50,
    offset: int = 0,
    db: Session = Depends(get_db),
):
    """List generation history with optional filters."""
    query = models.HistoryQuery(
        profile_id=profile_id,
        search=search,
        limit=limit,
        offset=offset,
    )
    return await history.list_generations(query, db)


@app.get("/history/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Get generation statistics."""
    return await history.get_generation_stats(db)


@app.post("/history/import")
async def import_generation(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Import a generation from a ZIP archive."""
    # Validate file size (max 50MB)
    MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
    
    # Read file content
    content = await file.read()
    
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
        )
    
    try:
        result = await export_import.import_generation_from_zip(content, db)
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/history/{generation_id}", response_model=models.HistoryResponse)
async def get_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Get a generation by ID."""
    # Get generation with profile name
    result = db.query(
        DBGeneration,
        DBVoiceProfile.name.label('profile_name')
    ).join(
        DBVoiceProfile,
        DBGeneration.profile_id == DBVoiceProfile.id
    ).filter(
        DBGeneration.id == generation_id
    ).first()
    
    if not result:
        raise HTTPException(status_code=404, detail="Generation not found")
    
    gen, profile_name = result
    return models.HistoryResponse(
        id=gen.id,
        profile_id=gen.profile_id,
        profile_name=profile_name,
        text=gen.text,
        language=gen.language,
        audio_path=gen.audio_path,
        duration=gen.duration,
        seed=gen.seed,
        instruct=gen.instruct,
        created_at=gen.created_at,
    )


@app.delete("/history/{generation_id}")
async def delete_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Delete a generation."""
    success = await history.delete_generation(generation_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Generation not found")
    return {"message": "Generation deleted successfully"}


@app.get("/history/{generation_id}/export")
async def export_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Export a generation as a ZIP archive."""
    try:
        # Get generation to create filename
        generation = db.query(DBGeneration).filter_by(id=generation_id).first()
        if not generation:
            raise HTTPException(status_code=404, detail="Generation not found")
        
        # Export to ZIP
        zip_bytes = export_import.export_generation_to_zip(generation_id, db)
        
        # Create safe filename from text
        safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
        if not safe_text:
            safe_text = "generation"
        filename = f"generation-{safe_text}.voicebox.zip"
        
        # Return as streaming response
        return StreamingResponse(
            io.BytesIO(zip_bytes),
            media_type="application/zip",
            headers={
                "Content-Disposition": _safe_content_disposition("attachment", filename)
            }
        )
    except HTTPException:
        # Re-raise as-is so the 404 above is not converted into a 500
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/history/{generation_id}/export-audio")
async def export_generation_audio(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Export only the audio file from a generation."""
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        raise HTTPException(status_code=404, detail="Generation not found")
    
    audio_path = Path(generation.audio_path)
    if not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")
    
    # Create safe filename from text
    safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    if not safe_text:
        safe_text = "generation"
    filename = f"{safe_text}.wav"
    
    return FileResponse(
        audio_path,
        media_type="audio/wav",
        headers={
            "Content-Disposition": _safe_content_disposition("attachment", filename)
        }
    )


# ============================================
# TRANSCRIPTION ENDPOINTS
# ============================================

@app.post("/transcribe", response_model=models.TranscriptionResponse)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
):
    """Transcribe audio file to text."""
    # Save uploaded file to temporary location
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    
    try:
        # Get audio duration
        from .utils.audio import load_audio
        audio, sr = load_audio(tmp_path)
        duration = len(audio) / sr
        
        # Transcribe
        whisper_model = transcribe.get_whisper_model()

        # Check if Whisper model is downloaded (uses default size "base")
        model_size = whisper_model.model_size
        model_name = f"openai/whisper-{model_size}"

        # Check if model is cached
        from huggingface_hub import constants as hf_constants
        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--"))
        if not repo_cache.exists():
            # Start download in background
            progress_model_name = f"whisper-{model_size}"

            async def download_whisper_background():
                try:
                    await whisper_model.load_model_async(model_size)
                except Exception as e:
                    get_task_manager().error_download(progress_model_name, str(e))

            get_task_manager().start_download(progress_model_name)
            asyncio.create_task(download_whisper_background())

            # Return 202 Accepted
            raise HTTPException(
                status_code=202,
                detail={
                    "message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
                    "model_name": progress_model_name,
                    "downloading": True
                }
            )

        text = await whisper_model.transcribe(tmp_path, language)
        
        return models.TranscriptionResponse(
            text=text,
            duration=duration,
        )
        
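    except HTTPException:
        # Let deliberate responses (such as the 202 above) pass through unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))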
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)


# ============================================
# STORY ENDPOINTS
# ============================================

@app.get("/stories", response_model=List[models.StoryResponse])
async def list_stories(db: Session = Depends(get_db)):
    """List all stories."""
    return await stories.list_stories(db)


@app.post("/stories", response_model=models.StoryResponse)
async def create_story(
    data: models.StoryCreate,
    db: Session = Depends(get_db),
):
    """Create a new story."""
    try:
        return await stories.create_story(data, db)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
async def get_story(
    story_id: str,
    db: Session = Depends(get_db),
):
    """Get a story with all its items."""
    story = await stories.get_story(story_id, db)
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")
    return story


@app.put("/stories/{story_id}", response_model=models.StoryResponse)
async def update_story(
    story_id: str,
    data: models.StoryCreate,
    db: Session = Depends(get_db),
):
    """Update a story."""
    story = await stories.update_story(story_id, data, db)
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")
    return story


@app.delete("/stories/{story_id}")
async def delete_story(
    story_id: str,
    db: Session = Depends(get_db),
):
    """Delete a story."""
    success = await stories.delete_story(story_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Story not found")
    return {"message": "Story deleted successfully"}


@app.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
async def add_story_item(
    story_id: str,
    data: models.StoryItemCreate,
    db: Session = Depends(get_db),
):
    """Add a generation to a story."""
    item = await stories.add_item_to_story(story_id, data, db)
    if not item:
        raise HTTPException(status_code=404, detail="Story or generation not found")
    return item


@app.delete("/stories/{story_id}/items/{item_id}")
async def remove_story_item(
    story_id: str,
    item_id: str,
    db: Session = Depends(get_db),
):
    """Remove a story item from a story."""
    success = await stories.remove_item_from_story(story_id, item_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Story item not found")
    return {"message": "Item removed successfully"}


@app.put("/stories/{story_id}/items/times")
async def update_story_item_times(
    story_id: str,
    data: models.StoryItemBatchUpdate,
    db: Session = Depends(get_db),
):
    """Update story item timecodes."""
    success = await stories.update_story_item_times(story_id, data, db)
    if not success:
        raise HTTPException(status_code=400, detail="Invalid timecode update request")
    return {"message": "Item timecodes updated successfully"}


@app.put("/stories/{story_id}/items/reorder", response_model=List[models.StoryItemDetail])
async def reorder_story_items(
    story_id: str,
    data: models.StoryItemReorder,
    db: Session = Depends(get_db),
):
    """Reorder story items and recalculate timecodes."""
    items = await stories.reorder_story_items(story_id, data.generation_ids, db)
    if items is None:
        raise HTTPException(status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story")
    return items


@app.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
async def move_story_item(
    story_id: str,
    item_id: str,
    data: models.StoryItemMove,
    db: Session = Depends(get_db),
):
    """Move a story item (update position and/or track)."""
    item = await stories.move_story_item(story_id, item_id, data, db)
    if item is None:
        raise HTTPException(status_code=404, detail="Story item not found")
    return item


@app.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
async def trim_story_item(
    story_id: str,
    item_id: str,
    data: models.StoryItemTrim,
    db: Session = Depends(get_db),
):
    """Trim a story item (update trim_start_ms and trim_end_ms)."""
    item = await stories.trim_story_item(story_id, item_id, data, db)
    if item is None:
        raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
    return item


@app.post("/stories/{story_id}/items/{item_id}/split", response_model=List[models.StoryItemDetail])
async def split_story_item(
    story_id: str,
    item_id: str,
    data: models.StoryItemSplit,
    db: Session = Depends(get_db),
):
    """Split a story item at a given time, creating two clips."""
    items = await stories.split_story_item(story_id, item_id, data, db)
    if items is None:
        raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
    return items


@app.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
async def duplicate_story_item(
    story_id: str,
    item_id: str,
    db: Session = Depends(get_db),
):
    """Duplicate a story item, creating a copy with all properties."""
    item = await stories.duplicate_story_item(story_id, item_id, db)
    if item is None:
        raise HTTPException(status_code=404, detail="Story item not found")
    return item


@app.get("/stories/{story_id}/export-audio")
async def export_story_audio(
    story_id: str,
    db: Session = Depends(get_db),
):
    """Export story as single mixed audio file with timecode-based mixing."""
    try:
        # Get story to create filename
        story = db.query(database.Story).filter_by(id=story_id).first()
        if not story:
            raise HTTPException(status_code=404, detail="Story not found")
        
        # Export audio
        audio_bytes = await stories.export_story_audio(story_id, db)
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Story has no audio items")
        
        # Create safe filename
        safe_name = "".join(c for c in story.name if c.isalnum() or c in (' ', '-', '_')).strip()
        if not safe_name:
            safe_name = "story"
        filename = f"{safe_name}.wav"
        
        # Return as streaming response
        return StreamingResponse(
            io.BytesIO(audio_bytes),
            media_type="audio/wav",
            headers={
                "Content-Disposition": _safe_content_disposition("attachment", filename)
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# ============================================
# FILE SERVING
# ============================================

@app.get("/audio/{generation_id}")
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
    """Serve generated audio file."""
    generation = await history.get_generation(generation_id, db)
    if not generation:
        raise HTTPException(status_code=404, detail="Generation not found")
    
    audio_path = Path(generation.audio_path)
    if not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")
    
    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"generation_{generation_id}.wav",
    )


@app.get("/samples/{sample_id}")
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
    """Serve profile sample audio file."""
    from .database import ProfileSample as DBProfileSample
    
    sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if not sample:
        raise HTTPException(status_code=404, detail="Sample not found")
    
    audio_path = Path(sample.audio_path)
    if not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")
    
    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"sample_{sample_id}.wav",
    )


# ============================================
# MODEL MANAGEMENT
# ============================================

@app.post("/models/load")
async def load_model(model_size: str = "1.7B"):
    """Manually load TTS model."""
    try:
        tts_model = tts.get_tts_model()
        await tts_model.load_model_async(model_size)
        return {"message": f"Model {model_size} loaded successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/models/unload")
async def unload_model():
    """Unload TTS model to free memory."""
    try:
        tts.unload_tts_model()
        return {"message": "Model unloaded successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
    """Get model download progress via Server-Sent Events."""
    
    progress_manager = get_progress_manager()
    
    async def event_generator():
        """Generate SSE events for progress updates."""
        async for event in progress_manager.subscribe(model_name):
            yield event
    
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@app.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
    """Get status of all available models."""
    from huggingface_hub import constants as hf_constants
    
    backend_type = get_backend_type()
    task_manager = get_task_manager()
    
    # Get set of currently downloading model names
    active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
    
    # Try to import scan_cache_dir (might not be available in older versions)
    try:
        from huggingface_hub import scan_cache_dir
        use_scan_cache = True
    except ImportError:
        use_scan_cache = False
    
    def check_tts_loaded(model_size: str):
        """Check if TTS model is loaded with specific size."""
        try:
            tts_model = tts.get_tts_model()
            return tts_model.is_loaded() and getattr(tts_model, 'model_size', None) == model_size
        except Exception:
            return False
    
    def check_whisper_loaded(model_size: str):
        """Check if Whisper model is loaded with specific size."""
        try:
            whisper_model = transcribe.get_whisper_model()
            return whisper_model.is_loaded() and getattr(whisper_model, 'model_size', None) == model_size
        except Exception:
            return False
    
    # Use backend-specific TTS model IDs
    if backend_type == "mlx":
        tts_1_7b_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
        tts_0_6b_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"  # Fallback to 1.7B
    else:
        tts_1_7b_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
        tts_0_6b_id = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"

    # Both backends use the openai/whisper-* repos (MLX does not use mlx-community here)
    whisper_base_id = "openai/whisper-base"
    whisper_small_id = "openai/whisper-small"
    whisper_medium_id = "openai/whisper-medium"
    whisper_large_id = "openai/whisper-large"
    
    model_configs = [
        {
            "model_name": "qwen-tts-1.7B",
            "display_name": "Qwen TTS 1.7B",
            "hf_repo_id": tts_1_7b_id,
            "model_size": "1.7B",
            "check_loaded": lambda: check_tts_loaded("1.7B"),
        },
        {
            "model_name": "qwen-tts-0.6B",
            "display_name": "Qwen TTS 0.6B",
            "hf_repo_id": tts_0_6b_id,
            "model_size": "0.6B",
            "check_loaded": lambda: check_tts_loaded("0.6B"),
        },
        {
            "model_name": "whisper-base",
            "display_name": "Whisper Base",
            "hf_repo_id": whisper_base_id,
            "model_size": "base",
            "check_loaded": lambda: check_whisper_loaded("base"),
        },
        {
            "model_name": "whisper-small",
            "display_name": "Whisper Small",
            "hf_repo_id": whisper_small_id,
            "model_size": "small",
            "check_loaded": lambda: check_whisper_loaded("small"),
        },
        {
            "model_name": "whisper-medium",
            "display_name": "Whisper Medium",
            "hf_repo_id": whisper_medium_id,
            "model_size": "medium",
            "check_loaded": lambda: check_whisper_loaded("medium"),
        },
        {
            "model_name": "whisper-large",
            "display_name": "Whisper Large",
            "hf_repo_id": whisper_large_id,
            "model_size": "large",
            "check_loaded": lambda: check_whisper_loaded("large"),
        },
    ]
    
    # Build a mapping of model_name -> hf_repo_id so we can check if shared repos are downloading
    model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
    
    # Get the set of hf_repo_ids that are currently being downloaded
    # This handles the case where multiple models share the same repo (e.g., 0.6B and 1.7B on MLX)
    active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
    
    # Get HuggingFace cache info (if available)
    cache_info = None
    if use_scan_cache:
        try:
            cache_info = scan_cache_dir()
        except Exception:
            # Function failed, continue without it
            pass
    
    statuses = []
    
    for config in model_configs:
        try:
            downloaded = False
            size_mb = None
            loaded = False
            
            # Method 1: Try using scan_cache_dir if available
            if cache_info:
                repo_id = config["hf_repo_id"]
                for repo in cache_info.repos:
                    if repo.repo_id == repo_id:
                        # Check if actual model weight files exist (not just config files)
                        # scan_cache_dir only shows completed files, so check if any are model weights
                        has_model_weights = False
                        for rev in repo.revisions:
                            for f in rev.files:
                                fname = f.file_name.lower()
                                if fname.endswith(('.safetensors', '.bin', '.pt', '.pth', '.npz')):
                                    has_model_weights = True
                                    break
                            if has_model_weights:
                                break
                        
                        # Also check for .incomplete files in blobs directory (downloads in progress)
                        has_incomplete = False
                        try:
                            cache_dir = hf_constants.HF_HUB_CACHE
                            blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
                            if blobs_dir.exists():
                                has_incomplete = any(blobs_dir.glob("*.incomplete"))
                        except Exception:
                            pass
                        
                        # Only mark as downloaded if we have model weights AND no incomplete files
                        if has_model_weights and not has_incomplete:
                            downloaded = True
                            # Calculate size from cache info
                            try:
                                total_size = sum(revision.size_on_disk for revision in repo.revisions)
                                size_mb = total_size / (1024 * 1024)
                            except Exception:
                                pass
                        break
            
            # Method 2: Fallback to checking cache directory directly (using HuggingFace's OS-specific cache location)
            if not downloaded:
                try:
                    cache_dir = hf_constants.HF_HUB_CACHE
                    repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
                    
                    if repo_cache.exists():
                        # Check for .incomplete files - if any exist, download is still in progress
                        blobs_dir = repo_cache / "blobs"
                        has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
                        
                        if not has_incomplete:
                            # Check for actual model weight files (not just index files)
                            # in the snapshots directory (symlinks to completed blobs)
                            snapshots_dir = repo_cache / "snapshots"
                            has_model_files = False
                            if snapshots_dir.exists():
                                has_model_files = (
                                    any(snapshots_dir.rglob("*.bin")) or
                                    any(snapshots_dir.rglob("*.safetensors")) or
                                    any(snapshots_dir.rglob("*.pt")) or
                                    any(snapshots_dir.rglob("*.pth")) or
                                    any(snapshots_dir.rglob("*.npz"))
                                )
                            
                            if has_model_files:
                                downloaded = True
                                # Calculate size (exclude .incomplete files)
                                try:
                                    total_size = sum(
                                        f.stat().st_size for f in repo_cache.rglob("*") 
                                        if f.is_file() and not f.name.endswith('.incomplete')
                                    )
                                    size_mb = total_size / (1024 * 1024)
                                except Exception:
                                    pass
                except Exception:
                    pass
            
            # Method 3 removed - checking for config.json is too lenient
            # Methods 1 and 2 properly verify that model weight files exist
            
            # Check if loaded in memory
            try:
                loaded = config["check_loaded"]()
            except Exception:
                loaded = False
            
            # Check if this model (or its shared repo) is currently being downloaded
            is_downloading = config["hf_repo_id"] in active_download_repos
            
            # If downloading, don't report as downloaded (partial files exist)
            if is_downloading:
                downloaded = False
                size_mb = None  # Don't show partial size during download
            
            statuses.append(models.ModelStatus(
                model_name=config["model_name"],
                display_name=config["display_name"],
                downloaded=downloaded,
                downloading=is_downloading,
                size_mb=size_mb,
                loaded=loaded,
            ))
        except Exception:
            # If check fails, try to at least check if loaded
            try:
                loaded = config["check_loaded"]()
            except Exception:
                loaded = False
            
            # Check if this model (or its shared repo) is currently being downloaded
            is_downloading = config["hf_repo_id"] in active_download_repos
            
            statuses.append(models.ModelStatus(
                model_name=config["model_name"],
                display_name=config["display_name"],
                downloaded=False,  # Assume not downloaded if check failed
                downloading=is_downloading,
                size_mb=None,
                loaded=loaded,
            ))
    
    return models.ModelStatusListResponse(models=statuses)


@app.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
    """Trigger download of a specific model."""
    
    task_manager = get_task_manager()
    progress_manager = get_progress_manager()
    
    model_configs = {
        "qwen-tts-1.7B": {
            "model_size": "1.7B",
            "load_func": lambda: tts.get_tts_model().load_model("1.7B"),
        },
        "qwen-tts-0.6B": {
            "model_size": "0.6B",
            "load_func": lambda: tts.get_tts_model().load_model("0.6B"),
        },
        "whisper-base": {
            "model_size": "base",
            "load_func": lambda: transcribe.get_whisper_model().load_model("base"),
        },
        "whisper-small": {
            "model_size": "small",
            "load_func": lambda: transcribe.get_whisper_model().load_model("small"),
        },
        "whisper-medium": {
            "model_size": "medium",
            "load_func": lambda: transcribe.get_whisper_model().load_model("medium"),
        },
        "whisper-large": {
            "model_size": "large",
            "load_func": lambda: transcribe.get_whisper_model().load_model("large"),
        },
    }
    
    if request.model_name not in model_configs:
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
    
    config = model_configs[request.model_name]
    
    async def download_in_background():
        """Download model in background without blocking the HTTP request."""
        try:
            # Call the load function (which may be async)
            result = config["load_func"]()
            # If it's a coroutine, await it
            if asyncio.iscoroutine(result):
                await result
            task_manager.complete_download(request.model_name)
        except Exception as e:
            task_manager.error_download(request.model_name, str(e))

    # Start tracking download
    task_manager.start_download(request.model_name)
    
    # Initialize progress state so SSE endpoint has initial data to send.
    # This fixes a race condition where the frontend connects to SSE before
    # any progress callbacks have fired (especially for large models like Qwen
    # where huggingface_hub takes time to fetch metadata for all files).
    progress_manager.update_progress(
        model_name=request.model_name,
        current=0,
        total=0,  # Will be updated once actual total is known
        filename="Connecting to HuggingFace...",
        status="downloading",
    )

    # Start download in background task (don't await)
    asyncio.create_task(download_in_background())

    # Return immediately - frontend should follow /models/progress/{model_name} for updates
    return {"message": f"Model {request.model_name} download started"}


@app.delete("/models/{model_name}")
async def delete_model(model_name: str):
    """Delete a downloaded model from the HuggingFace cache."""
    import shutil
    from huggingface_hub import constants as hf_constants
    
    # Map model names to HuggingFace repo IDs
    model_configs = {
        "qwen-tts-1.7B": {
            "hf_repo_id": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
            "model_size": "1.7B",
            "model_type": "tts",
        },
        "qwen-tts-0.6B": {
            "hf_repo_id": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
            "model_size": "0.6B",
            "model_type": "tts",
        },
        "whisper-base": {
            "hf_repo_id": "openai/whisper-base",
            "model_size": "base",
            "model_type": "whisper",
        },
        "whisper-small": {
            "hf_repo_id": "openai/whisper-small",
            "model_size": "small",
            "model_type": "whisper",
        },
        "whisper-medium": {
            "hf_repo_id": "openai/whisper-medium",
            "model_size": "medium",
            "model_type": "whisper",
        },
        "whisper-large": {
            "hf_repo_id": "openai/whisper-large",
            "model_size": "large",
            "model_type": "whisper",
        },
    }
    
    if model_name not in model_configs:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
    
    config = model_configs[model_name]
    hf_repo_id = config["hf_repo_id"]
    
    try:
        # Check if model is loaded and unload it first
        if config["model_type"] == "tts":
            tts_model = tts.get_tts_model()
            if tts_model.is_loaded() and tts_model.model_size == config["model_size"]:
                tts.unload_tts_model()
        elif config["model_type"] == "whisper":
            whisper_model = transcribe.get_whisper_model()
            if whisper_model.is_loaded() and whisper_model.model_size == config["model_size"]:
                transcribe.unload_whisper_model()
        
        # Find and delete the cache directory (using HuggingFace's OS-specific cache location)
        cache_dir = hf_constants.HF_HUB_CACHE
        repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
        
        # Check if the cache directory exists
        if not repo_cache_dir.exists():
            raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
        
        # Delete the entire cache directory for this model
        try:
            shutil.rmtree(repo_cache_dir)
        except OSError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to delete model cache directory: {str(e)}"
            )
        
        return {"message": f"Model {model_name} deleted successfully"}
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")


@app.post("/cache/clear")
async def clear_cache():
    """Clear all voice prompt caches (memory and disk)."""
    try:
        deleted_count = clear_voice_prompt_cache()
        return {
            "message": "Voice prompt cache cleared successfully",
            "files_deleted": deleted_count,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")


# ============================================
# TASK MANAGEMENT
# ============================================

@app.get("/tasks/active", response_model=models.ActiveTasksResponse)
async def get_active_tasks():
    """Return all currently active downloads and generations."""
    task_manager = get_task_manager()
    progress_manager = get_progress_manager()
    
    # Get active downloads from both task manager and progress manager
    # Task manager tracks which downloads are active
    # Progress manager has the actual progress data
    active_downloads = []
    task_manager_downloads = task_manager.get_active_downloads()
    progress_active = progress_manager.get_all_active()
    
    # Combine data from both sources
    download_map = {task.model_name: task for task in task_manager_downloads}
    progress_map = {p["model_name"]: p for p in progress_active}
    
    # Create unified list
    all_model_names = set(download_map.keys()) | set(progress_map.keys())
    for model_name in all_model_names:
        task = download_map.get(model_name)
        progress = progress_map.get(model_name)
        
        if task:
            active_downloads.append(models.ActiveDownloadTask(
                model_name=model_name,
                status=task.status,
                started_at=task.started_at,
            ))
        elif progress:
            # Progress exists but no task - create from progress data
            timestamp_str = progress.get("timestamp")
            if timestamp_str:
                try:
                    started_at = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
                except (ValueError, AttributeError):
                    started_at = datetime.utcnow()
            else:
                started_at = datetime.utcnow()
            
            active_downloads.append(models.ActiveDownloadTask(
                model_name=model_name,
                status=progress.get("status", "downloading"),
                started_at=started_at,
            ))
    
    # Get active generations
    active_generations = []
    for gen_task in task_manager.get_active_generations():
        active_generations.append(models.ActiveGenerationTask(
            task_id=gen_task.task_id,
            profile_id=gen_task.profile_id,
            text_preview=gen_task.text_preview,
            started_at=gen_task.started_at,
        ))
    
    return models.ActiveTasksResponse(
        downloads=active_downloads,
        generations=active_generations,
    )


# ============================================
# STARTUP & SHUTDOWN
# ============================================

def _get_gpu_status() -> str:
    """Get GPU availability status."""
    backend_type = get_backend_type()
    if torch.cuda.is_available():
        return f"CUDA ({torch.cuda.get_device_name(0)})"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "MPS (Apple Silicon)"
    elif backend_type == "mlx":
        return "Metal (Apple Silicon via MLX)"
    return "None (CPU only)"


@app.on_event("startup")
async def startup_event():
    """Run on application startup."""
    print("voicebox API starting up...")
    database.init_db()
    print(f"Database initialized at {database._db_path}")
    backend_type = get_backend_type()
    print(f"Backend: {backend_type.upper()}")
    print(f"GPU available: {_get_gpu_status()}")

    # Initialize progress manager with main event loop for thread-safe operations
    try:
        progress_manager = get_progress_manager()
        progress_manager._set_main_loop(asyncio.get_running_loop())
        print("Progress manager initialized with event loop")
    except Exception as e:
        print(f"Warning: Could not initialize progress manager event loop: {e}")

    # Ensure HuggingFace cache directory exists
    try:
        from huggingface_hub import constants as hf_constants
        cache_dir = Path(hf_constants.HF_HUB_CACHE)
        cache_dir.mkdir(parents=True, exist_ok=True)
        print(f"HuggingFace cache directory: {cache_dir}")
    except Exception as e:
        print(f"Warning: Could not create HuggingFace cache directory: {e}")
        print("Model downloads may fail. Please ensure the directory exists and has write permissions.")


@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown."""
    print("voicebox API shutting down...")
    # Unload models to free memory
    tts.unload_tts_model()
    transcribe.unload_whisper_model()


# ============================================
# MAIN
# ============================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="voicebox backend server")
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host to bind to (use 0.0.0.0 for remote access)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default=None,
        help="Data directory for database, profiles, and generated audio",
    )
    args = parser.parse_args()

    # Set data directory if provided
    if args.data_dir:
        config.set_data_dir(args.data_dir)

    # Initialize database after data directory is set
    database.init_db()

    uvicorn.run(
        "backend.main:app",
        host=args.host,
        port=args.port,
        reload=False,  # Disable reload in production
    )
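
The cache lookup in delete_model depends on how the Hub cache names repository folders. The following is a minimal sketch of that mapping, assuming the models--{org}--{name} layout that the listing above relies on. The helper name repo_cache_dir is hypothetical; it is not part of voicebox or huggingface_hub, and it only repeats the string manipulation shown above.

```python
from pathlib import Path


def repo_cache_dir(cache_root: str, hf_repo_id: str) -> Path:
    """Return the folder the Hub cache is assumed to use for a model repo.

    Hypothetical helper: it mirrors the path construction in delete_model
    and does not call huggingface_hub.
    """
    # "openai/whisper-base" -> "models--openai--whisper-base"
    return Path(cache_root) / ("models--" + hf_repo_id.replace("/", "--"))


if __name__ == "__main__":
    # Print the folder name for one of the repositories listed in model_configs
    print(repo_cache_dir("hub", "Qwen/Qwen3-TTS-12Hz-1.7B-Base").name)
```

The same expression appears again in hf_config.py and pytorch_backend.py, so a shared helper of this kind could keep the three call sites consistent.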

2. backend/utils/cache.py

"""
Voice prompt caching utilities.
"""

import hashlib
import torch
import warnings
from pathlib import Path
from typing import Optional, Union, Dict, Any

from .. import config


def _get_cache_dir() -> Path:
    """Get cache directory from config."""
    return config.get_cache_dir()


# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}


def get_cache_key(audio_path: str, reference_text: str) -> str:
    """
    Generate cache key from audio file and reference text.

    Args:
        audio_path: Path to audio file
        reference_text: Reference text

    Returns:
        Cache key (MD5 hash)
    """
    # Read audio file
    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    # Combine audio bytes and text
    combined = audio_bytes + reference_text.encode("utf-8")

    # Generate hash
    return hashlib.md5(combined).hexdigest()


def get_cached_voice_prompt(
    cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
    """
    Get cached voice prompt if available.

    Args:
        cache_key: Cache key

    Returns:
        Cached voice prompt (dict or tensor) or None
    """
    # Check in-memory cache
    if cache_key in _memory_cache:
        return _memory_cache[cache_key]

    # Check disk cache
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    if cache_file.exists():
        try:
            # Suppress the FutureWarning and load with weights_only=False:
            # the cached voice prompt was generated locally, so it is trusted.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=FutureWarning)
                prompt = torch.load(cache_file, weights_only=False)
            _memory_cache[cache_key] = prompt
            return prompt
        except Exception:
            # Cache file is unreadable or corrupted; delete it and report a miss
            cache_file.unlink(missing_ok=True)

    return None


def cache_voice_prompt(
    cache_key: str,
    voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
    """
    Cache voice prompt to memory and disk.

    Args:
        cache_key: Cache key
        voice_prompt: Voice prompt (dict or tensor)
    """
    # Store in memory
    _memory_cache[cache_key] = voice_prompt

    # Store on disk (torch.save can handle both dicts and tensors)
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    torch.save(voice_prompt, cache_file)


def clear_voice_prompt_cache() -> int:
    """
    Clear all voice prompt caches (memory and disk).
    
    Returns:
        Number of cache files deleted
    """
    # Clear memory cache
    _memory_cache.clear()
    
    # Clear disk cache
    cache_dir = _get_cache_dir()
    deleted_count = 0
    
    if cache_dir.exists():
        # Delete prompt cache files
        for cache_file in cache_dir.glob("*.prompt"):
            try:
                cache_file.unlink()
                deleted_count += 1
            except Exception as e:
                print(f"Failed to delete cache file {cache_file}: {e}")
        
        # Delete combined audio files
        for audio_file in cache_dir.glob("combined_*.wav"):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                print(f"Failed to delete combined audio file {audio_file}: {e}")
    
    return deleted_count


def clear_profile_cache(profile_id: str) -> int:
    """
    Clear cache files for a specific profile.
    
    Args:
        profile_id: Profile ID
    
    Returns:
        Number of cache files deleted
    """
    cache_dir = _get_cache_dir()
    deleted_count = 0
    
    if cache_dir.exists():
        # Delete combined audio files for this profile
        pattern = f"combined_{profile_id}_*.wav"
        for audio_file in cache_dir.glob(pattern):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                print(f"Failed to delete combined audio file {audio_file}: {e}")
    
    return deleted_count

3. backend/utils/hf_config.py

"""
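HuggingFace Hub configuration utilities.

Provides:
1. Disabling SSL certificate verification
2. Using a mainland-China mirror to speed up downloads
3. Optimized local model loading that avoids repeated network access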
"""

import os
import ssl
import warnings
from pathlib import Path
from typing import Optional

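# Mirror endpoints in mainland China. Only the first entry is used as the
# default HF_ENDPOINT by configure_huggingface_hub().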
HF_MIRRORS = [
    "https://hf-mirror.com",
    "https://modelscope.cn/api/v1/models",
]


def configure_huggingface_hub(
    disable_ssl_verify: bool = True,
    mirror_url: Optional[str] = None,
    local_files_only: Optional[bool] = None,
):
    """
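    Configure HuggingFace Hub settings.

    Args:
        disable_ssl_verify: Whether to disable SSL certificate verification
        mirror_url: Mirror URL; if None, the default mirror is used
        local_files_only: Whether to use local files only; if None, HF_HUB_OFFLINE is left unchanged
    """
    # Disable SSL certificate verification
    if disable_ssl_verify:
        _disable_ssl_verification()

    # Set the mirror endpoint
    if mirror_url:
        os.environ["HF_ENDPOINT"] = mirror_url
    elif "HF_ENDPOINT" not in os.environ:
        # Fall back to the default mirror
        os.environ["HF_ENDPOINT"] = HF_MIRRORS[0]

    # Set offline (local-files-only) mode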
    if local_files_only is not None:
        os.environ["HF_HUB_OFFLINE"] = "1" if local_files_only else "0"


def _disable_ssl_verification():
    """Disable SSL certificate verification."""
    try:
        # Make the standard library's default HTTPS context skip certificate checks
        ssl._create_default_https_context = ssl._create_unverified_context

        # Suppress SSL-related warnings
        warnings.filterwarnings("ignore", message="Unverified HTTPS request")
        warnings.filterwarnings("ignore", category=UserWarning, message=".*SSL.*")

        # Try to silence urllib3's InsecureRequestWarning
        try:
            import urllib3
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except Exception:
            pass

    except Exception as e:
        print(f"[HF Config] Warning: Could not disable SSL verification: {e}")


def get_model_download_kwargs(
    model_name: str,
    is_cached: bool,
    force_download: bool = False,
) -> dict:
    """
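    Build the keyword arguments for loading or downloading a model.

    Args:
        model_name: Model name
        is_cached: Whether the model is already cached
        force_download: Whether to force a re-download

    Returns:
        Dictionary of download keyword arguments
    """
    kwargs = {}

    # If the model is cached and no re-download is forced, use local files only
    if is_cached and not force_download:
        kwargs["local_files_only"] = True
        print(f"[HF Config] Using local files for {model_name}")
    else:
        # A download is needed: allow network access and trust the repo's remote code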
        kwargs["local_files_only"] = False
        kwargs["trust_remote_code"] = True
        print(f"[HF Config] Will download {model_name} from HuggingFace Hub")

    return kwargs


def is_model_fully_cached(
    model_id: str,
    cache_dir: Optional[str] = None,
) -> bool:
    """
    Check whether a model is fully cached.

    Args:
        model_id: HuggingFace model ID
        cache_dir: Cache directory; if None, the default directory is used

    Returns:
        True if the model is fully cached
    """
    try:
        from huggingface_hub import constants as hf_constants

        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE

        # Build the model's cache path
        repo_cache = Path(cache_dir) / ("models--" + model_id.replace("/", "--"))

        if not repo_cache.exists():
            return False

        # Check for unfinished downloads
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            print(f"[HF Config] Found incomplete downloads for {model_id}")
            return False

        # Check for model weight files
        snapshots_dir = repo_cache / "snapshots"
        if snapshots_dir.exists():
            has_weights = (
                any(snapshots_dir.rglob("*.safetensors")) or
                any(snapshots_dir.rglob("*.bin")) or
                any(snapshots_dir.rglob("*.pt")) or
                any(snapshots_dir.rglob("*.pth")) or
                any(snapshots_dir.rglob("*.npz"))
            )
            if not has_weights:
                print(f"[HF Config] No model weights found for {model_id}")
                return False

        return True
    except Exception as e:
        print(f"[HF Config] Error checking cache for {model_id}: {e}")
        return False


def setup_huggingface_for_offline():
    """
    Put HuggingFace into offline mode so that only local files are used.

    Call this once the model has been downloaded to avoid network access.
    It must be called before transformers or huggingface_hub is imported.
    """
    # Set environment variables
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"

    # Disable huggingface_hub network requests
    os.environ["HF_UPDATE_DOWNLOAD_COUNTS"] = "0"
    os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

    # Set offline mode in transformers (private attribute; may not exist in every version)
    try:
        import transformers
        transformers.utils.hub._is_offline_mode = True
    except Exception:
        pass

    # Try to set offline mode in huggingface_hub
    try:
        import huggingface_hub
        huggingface_hub.constants.HF_HUB_OFFLINE = True
    except Exception:
        pass

    print("[HF Config] Set to offline mode - will only use local files")


def setup_huggingface_for_online():
    """
    Put HuggingFace into online mode, allowing downloads from the network.
    """
    os.environ["HF_HUB_OFFLINE"] = "0"
    os.environ["TRANSFORMERS_OFFLINE"] = "0"
    os.environ["HF_UPDATE_DOWNLOAD_COUNTS"] = "1"

    # Set online mode in transformers (private attribute; may not exist in every version)
    try:
        import transformers
        transformers.utils.hub._is_offline_mode = False
    except Exception:
        pass

    # Try to set online mode in huggingface_hub
    try:
        import huggingface_hub
        huggingface_hub.constants.HF_HUB_OFFLINE = False
    except Exception:
        pass

    print("[HF Config] Set to online mode - will download from network if needed")


def get_huggingface_config() -> dict:
    """
    Return the current HuggingFace configuration.

    Returns:
        Dictionary of configuration values
    """
    return {
        "HF_ENDPOINT": os.environ.get("HF_ENDPOINT"),
        "HF_HUB_OFFLINE": os.environ.get("HF_HUB_OFFLINE"),
        "TRANSFORMERS_OFFLINE": os.environ.get("TRANSFORMERS_OFFLINE"),
        "HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
    }


def clean_no_exist_cache(model_id: str, cache_dir: Optional[str] = None) -> int:
    """
    Remove a model's .no_exist directory.

    The .no_exist directory records which files do not exist in the remote repository.
    In offline mode these records may trigger unnecessary network access attempts.
    Removing the directory stops HuggingFace Hub from trying to verify those files in offline mode.

    Args:
        model_id: HuggingFace model ID
        cache_dir: Cache directory; if None, the default directory is used

    Returns:
        Number of files deleted
    """
    try:
        from huggingface_hub import constants as hf_constants
        import shutil

        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE

        # Build the .no_exist directory path
        repo_cache = Path(cache_dir) / ("models--" + model_id.replace("/", "--"))
        no_exist_dir = repo_cache / ".no_exist"

        if not no_exist_dir.exists():
            print(f"[HF Config] No .no_exist directory found for {model_id}")
            return 0

        # Count the files
        file_count = sum(1 for p in no_exist_dir.rglob("*") if p.is_file())

        # Delete the .no_exist directory
        shutil.rmtree(no_exist_dir)
        print(f"[HF Config] Cleaned .no_exist directory for {model_id} ({file_count} files removed)")

        return file_count
    except Exception as e:
        print(f"[HF Config] Error cleaning .no_exist directory for {model_id}: {e}")
        return 0


def clean_all_no_exist_cache(cache_dir: Optional[str] = None) -> int:
    """
    Remove the .no_exist directories of all models.

    Args:
        cache_dir: Cache directory; if None, the default directory is used

    Returns:
        Total number of files deleted
    """
    try:
        from huggingface_hub import constants as hf_constants
        import shutil

        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE

        cache_path = Path(cache_dir)
        total_files = 0

        # Find every model's .no_exist directory
        for model_dir in cache_path.glob("models--*"):
            no_exist_dir = model_dir / ".no_exist"
            if no_exist_dir.exists():
                file_count = sum(1 for p in no_exist_dir.rglob("*") if p.is_file())
                shutil.rmtree(no_exist_dir)
                print(f"[HF Config] Cleaned .no_exist for {model_dir.name} ({file_count} files)")
                total_files += file_count

        print(f"[HF Config] Total .no_exist files cleaned: {total_files}")
        return total_files
    except Exception as e:
        print(f"[HF Config] Error cleaning all .no_exist directories: {e}")
        return 0


# Configure automatically when this module is imported
configure_huggingface_hub(disable_ssl_verify=True)
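
setup_huggingface_for_offline and setup_huggingface_for_online overwrite process-wide environment variables and leave them changed. When only one block of code should see a given setting, the override can be scoped and then restored. The sketch below uses only the standard library. temporary_env is a hypothetical helper, not part of voicebox or huggingface_hub. As the docstring of setup_huggingface_for_offline notes, the variables have to be set before the libraries are imported, so a helper like this is only useful around the code that performs that first import.

```python
import os
from contextlib import contextmanager
from typing import Dict, Iterator, Optional


@contextmanager
def temporary_env(overrides: Dict[str, Optional[str]]) -> Iterator[None]:
    """Apply environment overrides, then restore the previous values.

    A value of None removes the variable for the duration of the block.
    """
    # Remember what each variable held before (None means "was not set")
    previous = {name: os.environ.get(name) for name in overrides}
    try:
        for name, value in overrides.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        yield
    finally:
        # Restore the original state even if the body raised
        for name, old in previous.items():
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old
```

A caller might wrap the first model load in `with temporary_env({"HF_HUB_OFFLINE": "1", "TRANSFORMERS_OFFLINE": "1"}):` so that a later download request does not inherit the offline setting.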

4. backend/utils/warning_suppressor.py

"""
Warning suppression utilities.

Suppresses common warnings at application startup to keep the logs clean.
"""

import warnings
import os


def suppress_common_warnings():
    """
    Suppress common warnings.

    Covers:
    - FutureWarning from torch.load
    - UserWarning from transformers
    - Other known harmless warnings
    """
    # Suppress the torch.load FutureWarning
    warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")

    # Suppress the transformers Flash Attention warning
    warnings.filterwarnings("ignore", category=UserWarning, message=".*flash attention.*")

    # Suppress the transformers pad_token_id warning
    warnings.filterwarnings("ignore", category=UserWarning, message=".*pad_token_id.*")

    # Suppress every UserWarning raised from the transformers module
    warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

    # Set the environment variable that silences transformers advisory warnings
    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

    # Suppress all DeprecationWarning messages
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    print("[Warning Suppressor] Common warnings suppressed")


def suppress_ssl_warnings():
    """Suppress SSL-related warnings."""
    import urllib3

    # Disable InsecureRequestWarning
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    print("[Warning Suppressor] SSL warnings suppressed")
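
suppress_common_warnings installs its filters globally, so every DeprecationWarning in the process is hidden for the rest of the run. cache.py takes the narrower route of silencing a warning only around the call that triggers it. The sketch below illustrates that scoped pattern with the standard warnings module. noisy_load is a hypothetical stand-in for a call such as torch.load, and count_warnings exists only to make the effect observable.

```python
import warnings


def noisy_load() -> str:
    """Stand-in for a call that emits a FutureWarning (hypothetical)."""
    warnings.warn("You are using `torch.load` with `weights_only=False`", FutureWarning)
    return "loaded"


def quiet_load() -> str:
    """Suppress FutureWarning only while noisy_load() runs."""
    with warnings.catch_warnings():
        # The filter is discarded when the with-block exits
        warnings.simplefilter("ignore", category=FutureWarning)
        return noisy_load()


def count_warnings(func) -> int:
    """Return how many warnings func() lets through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        func()
    return len(caught)
```

With this pattern, warnings from unrelated code stay visible, which can help when a dependency upgrade later turns a deprecation into an error.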

5. backend/backends/pytorch_backend.py

"""
PyTorch backend implementation for TTS and STT.
"""

from typing import Optional, List, Tuple
import asyncio
import torch
import numpy as np
from pathlib import Path

from . import TTSBackend, STTBackend
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
from ..utils.hf_config import (
    get_model_download_kwargs,
    setup_huggingface_for_online,
)


class PyTorchTTSBackend:
    """PyTorch-based TTS backend using Qwen3-TTS."""
    
    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self.device = self._get_device()
        self._current_model_size = None
    
    def _get_device(self) -> str:
        """Get the best available device."""
        if torch.cuda.is_available():
            return "cuda"
        # Intel Arc / Intel Xe GPU via intel-extension-for-pytorch (IPEX)
        try:
            import intel_extension_for_pytorch  # noqa: F401
            if hasattr(torch, 'xpu') and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass
        # Any GPU on Windows via DirectML (torch-directml)
        try:
            import torch_directml
            if torch_directml.device_count() > 0:
                return torch_directml.device(0)
        except ImportError:
            pass
        # MPS (Apple Silicon) — kept for completeness but MLX backend is preferred
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "cpu"  # MPS disabled for stability; MLX backend handles Apple Silicon
        return "cpu"
    
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None
    
    def _get_model_path(self, model_size: str) -> str:
        """
        Get the HuggingFace Hub model ID.
        
        Args:
            model_size: Model size (1.7B or 0.6B)
            
        Returns:
            HuggingFace Hub model ID
        """
        hf_model_map = {
            "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
            "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
        }
        
        if model_size not in hf_model_map:
            raise ValueError(f"Unknown model size: {model_size}")
        
        return hf_model_map[model_size]
    
    def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
        """
        Get the local snapshot path if the model is fully cached.

        Args:
            model_size: Model size to check

        Returns:
            Path to local snapshot if fully cached, None otherwise
        """
        try:
            from huggingface_hub import constants as hf_constants
            model_id = self._get_model_path(model_size)
            repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))

            if not repo_cache.exists():
                return None

            # Check for .incomplete files - if any exist, download is still in progress
            blobs_dir = repo_cache / "blobs"
            if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
                return None

            # Check that actual model weight files exist in snapshots
            snapshots_dir = repo_cache / "snapshots"
            if not snapshots_dir.exists():
                return None

            # Get the latest snapshot (by modification time)
            snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
            if not snapshot_dirs:
                return None

            latest_snapshot = snapshot_dirs[0]

            # Check that the latest snapshot contains model weight files
            has_weights = (
                any(latest_snapshot.rglob("*.safetensors")) or
                any(latest_snapshot.rglob("*.bin"))
            )
            if not has_weights:
                return None

            # Check for config.json
            if not (latest_snapshot / "config.json").exists():
                return None

            return latest_snapshot
        except Exception as e:
            print(f"[_get_local_snapshot_path] Error: {e}")
            return None

    def _is_model_cached(self, model_size: str) -> bool:
        """
        Check if the model is already cached locally AND fully downloaded.

        Args:
            model_size: Model size to check

        Returns:
            True if model is fully cached, False if missing or incomplete
        """
        local_path = self._get_local_snapshot_path(model_size)
        if local_path:
            print(f"[_is_model_cached] Model {model_size} is fully cached at {local_path}")
        else:
            print(f"[_is_model_cached] Model {model_size} is not cached")
        return local_path is not None
    
    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the TTS model with automatic downloading from HuggingFace Hub.
        
        Args:
            model_size: Model size to load (1.7B or 0.6B)
        """
        if model_size is None:
            model_size = self.model_size
            
        # If already loaded with correct size, return
        if self.model is not None and self._current_model_size == model_size:
            return
        
        # Unload existing model if different size requested
        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()
        
        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)
    
    # Alias for compatibility
    load_model = load_model_async
    
    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        try:
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"

            # Get local snapshot path if model is cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # Get model ID for HuggingFace Hub (used for downloading)
            model_id = self._get_model_path(model_size)

            # Determine the path to use for loading
            # If cached, use local snapshot path directly to avoid any network access
            # If not cached, use HuggingFace Hub ID to download
            load_path = str(local_snapshot_path) if is_cached else model_id

            if is_cached:
                print(f"[TTS] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[TTS] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # Set up progress callback and tracker
            # If cached: filter out non-download progress (like "Segment 1/1" during generation)
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

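            # Patch tqdm BEFORE importing qwen_tts
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()

            # Import qwen_tts; undo the tqdm patch if the import fails,
            # because the try/finally further down is never reached in that case
            try:
                from qwen_tts import Qwen3TTSModel
            except ImportError:
                tracker_context.__exit__(None, None, None)
                raise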

            print(f"Loading TTS model {model_size} on {self.device}...")

            # Only track download progress if model is NOT cached
            if not is_cached:
                # Start tracking download task
                task_manager.start_download(model_name)

                # Initialize progress state so SSE endpoint has initial data to send
                progress_manager.update_progress(
                    model_name=model_name,
                    current=0,
                    total=0,  # Will be updated once actual total is known
                    filename="Connecting to HuggingFace...",
                    status="downloading",
                )

            # Load the model
            try:
                # A cached model is loaded straight from its local snapshot path,
                # so no download kwargs are needed; otherwise pass the Hub kwargs.
                load_kwargs = (
                    {} if is_cached else get_model_download_kwargs(model_name, is_cached)
                )
                if self.device == "cpu":
                    self.model = Qwen3TTSModel.from_pretrained(
                        load_path,
                        torch_dtype=torch.float32,
                        low_cpu_mem_usage=False,
                        **load_kwargs,
                    )
                else:
                    self.model = Qwen3TTSModel.from_pretrained(
                        load_path,
                        device_map=self.device,
                        torch_dtype=torch.bfloat16,
                        **load_kwargs,
                    )
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)

            
            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(model_name)
                task_manager.complete_download(model_name)
            
            self._current_model_size = model_size
            self.model_size = model_size
            
            print(f"TTS model {model_size} loaded successfully")

        except ImportError as e:
            print(f"Error: qwen_tts package not found. Install with: pip install git+https://github.com/QwenLM/Qwen3-TTS.git")
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"
            progress_manager.mark_error(model_name, str(e))
            task_manager.error_download(model_name, str(e))
            raise
        except Exception as e:
            error_msg = str(e)
            print(f"Error loading TTS model: {error_msg}")

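            # Detect offline-mode errors and print a friendlier hint
            if "offline mode" in error_msg.lower() or "cannot reach" in error_msg.lower():
                print("\n[Hint] The model files are cached, but the qwen_tts library needs a network connection to validate the model.")
                print("[Hint] Try one of the following:")
                print("  1. Connect to the network and retry (recommended)")
                print("  2. Set the environment variable HF_HUB_OFFLINE=0 and retry")
                print("  3. Check that the model cache is complete")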
            else:
                print(f"Tip: The model will be automatically downloaded from HuggingFace Hub on first use.")

            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"
            progress_manager.mark_error(model_name, error_msg)
            task_manager.error_download(model_name, error_msg)
            raise
    
    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            print("TTS model unloaded")
    
    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.
        
        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available
            
        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)
        
        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                # Cache stores as torch.Tensor but actual prompt is dict
                # Convert if needed
                if isinstance(cached_prompt, dict):
                    # For PyTorch backend, the dict should contain tensors, not file paths
                    # So we can safely return it
                    return cached_prompt, True
                elif isinstance(cached_prompt, torch.Tensor):
                    # Legacy cache format - convert to dict
                    # This shouldn't happen in practice, but handle it
                    return {"prompt": cached_prompt}, True
        
        def _create_prompt_sync():
            """Run synchronous voice prompt creation in thread pool."""
            return self.model.create_voice_clone_prompt(
                ref_audio=str(audio_path),
                ref_text=reference_text,
                x_vector_only_mode=False,
            )
        
        # Run blocking operation in thread pool
        voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)
        
        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)
        
        return voice_prompt_items, False
    
    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """
        Combine multiple reference samples for better quality.
        
        Args:
            audio_paths: List of audio file paths
            reference_texts: List of reference texts
            
        Returns:
            Tuple of (combined_audio, combined_text)
        """
        combined_audio = []
        
        for audio_path in audio_paths:
            audio, sr = load_audio(audio_path)
            audio = normalize_audio(audio)
            combined_audio.append(audio)
        
        # Concatenate audio
        mixed = np.concatenate(combined_audio)
        mixed = normalize_audio(mixed)
        
        # Combine texts
        combined_text = " ".join(reference_texts)
        
        return mixed, combined_text
    
    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary from create_voice_prompt
            language: Language code (en or zh); not forwarded to the model in this implementation
            seed: Random seed for reproducibility
            instruct: Natural language instruction for speech delivery control

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        # Load model
        await self.load_model_async(None)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # Set seed if provided
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            # Generate audio - this is the blocking operation
            wavs, sample_rate = self.model.generate_voice_clone(
                text=text,
                voice_clone_prompt=voice_prompt,
                instruct=instruct,
            )
            return wavs[0], sample_rate

        # Run blocking inference in thread pool to avoid blocking event loop
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


class PyTorchSTTBackend:
    """PyTorch-based STT backend using Whisper."""
    
    def __init__(self, model_size: str = "base"):
        self.model = None
        self.processor = None
        self.model_size = model_size
        self.device = self._get_device()
    
    def _get_device(self) -> str:
        """Get the best available device."""
        if torch.cuda.is_available():
            return "cuda"
        # Intel Arc / Intel Xe GPU via intel-extension-for-pytorch (IPEX)
        try:
            import intel_extension_for_pytorch  # noqa: F401
            if hasattr(torch, 'xpu') and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass
        # Any GPU on Windows via DirectML (torch-directml)
        try:
            import torch_directml
            if torch_directml.device_count() > 0:
                return torch_directml.device(0)
        except ImportError:
            pass
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "cpu"  # MPS disabled for stability
        return "cpu"
    
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None
    
    def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
        """
        Get the local snapshot path if the Whisper model is fully cached.

        Args:
            model_size: Model size to check

        Returns:
            Path to local snapshot if fully cached, None otherwise
        """
        try:
            from huggingface_hub import constants as hf_constants
            model_id = f"openai/whisper-{model_size}"
            repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))

            if not repo_cache.exists():
                return None

            # Check for .incomplete files - if any exist, download is still in progress
            blobs_dir = repo_cache / "blobs"
            if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
                return None

            # Check that actual model weight files exist in snapshots
            snapshots_dir = repo_cache / "snapshots"
            if not snapshots_dir.exists():
                return None

            # Get the latest snapshot (by modification time)
            snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
            if not snapshot_dirs:
                return None

            latest_snapshot = snapshot_dirs[0]

            # Check that the snapshot contains model weight files (rglob also matches symlinks)
            has_weights = (
                any(latest_snapshot.rglob("*.safetensors")) or
                any(latest_snapshot.rglob("*.bin"))
            )
            if not has_weights:
                return None

            # Check for config.json
            if not (latest_snapshot / "config.json").exists():
                return None

            return latest_snapshot
        except Exception as e:
            print(f"[_get_local_snapshot_path] Error: {e}")
            return None

    def _is_model_cached(self, model_size: str) -> bool:
        """
        Check if the Whisper model is already cached locally AND fully downloaded.

        Args:
            model_size: Model size to check

        Returns:
            True if model is fully cached, False if missing or incomplete
        """
        local_path = self._get_local_snapshot_path(model_size)
        if local_path:
            print(f"[_is_model_cached] Whisper model {model_size} is fully cached at {local_path}")
        else:
            print(f"[_is_model_cached] Whisper model {model_size} is not cached")
        return local_path is not None
    
    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        print(f"[DEBUG] load_model_async called with size: {model_size}")
        if model_size is None:
            model_size = self.model_size

        print(f"[DEBUG] Model already loaded? {self.model is not None}, current size: {self.model_size}, requested: {model_size}")
        if self.model is not None and self.model_size == model_size:
            print(f"[DEBUG] Early return - model already loaded")
            return

        print(f"[DEBUG] Calling asyncio.to_thread for _load_model_sync")
        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)
        print(f"[DEBUG] asyncio.to_thread completed")
    
    # Alias for compatibility
    load_model = load_model_async
    
    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        print(f"[DEBUG] _load_model_sync called for Whisper {model_size}")
        try:
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            progress_model_name = f"whisper-{model_size}"

            # Get local snapshot path if model is cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # Get model ID for HuggingFace Hub (used for downloading)
            model_id = f"openai/whisper-{model_size}"

            # Determine the path to use for loading
            # If cached, use local snapshot path directly to avoid any network access
            # If not cached, use HuggingFace Hub ID to download
            load_path = str(local_snapshot_path) if is_cached else model_id

            if is_cached:
                print(f"[Whisper] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[Whisper] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # Set up progress callback and tracker
            # If cached: filter out non-download progress
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(progress_model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

            # Patch tqdm BEFORE importing transformers
            print("[DEBUG] Starting tqdm patch BEFORE transformers import")
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()
            print("[DEBUG] tqdm patched, now importing transformers")

            # The import and the load both run inside try/finally so the tqdm
            # patch is always undone, even if the import or the load fails.
            try:
                # Import transformers
                from transformers import WhisperProcessor, WhisperForConditionalGeneration

                print(f"[DEBUG] Model name: {model_id}")

                print(f"Loading Whisper model {model_size} on {self.device}...")

                # Only track download progress if model is NOT cached
                if not is_cached:
                    # Start tracking download task
                    task_manager.start_download(progress_model_name)

                    # Initialize progress state so SSE endpoint has initial data to send
                    progress_manager.update_progress(
                        model_name=progress_model_name,
                        current=0,
                        total=0,  # Will be updated once actual total is known
                        filename="Connecting to HuggingFace...",
                        status="downloading",
                    )

                # Load models (tqdm is patched, but filters out non-download progress)
                if is_cached:
                    # Load directly from local path - no network access
                    self.processor = WhisperProcessor.from_pretrained(load_path)
                    self.model = WhisperForConditionalGeneration.from_pretrained(load_path)
                else:
                    # Load from HuggingFace Hub - will download
                    download_kwargs = get_model_download_kwargs(progress_model_name, is_cached)
                    self.processor = WhisperProcessor.from_pretrained(load_path, **download_kwargs)
                    self.model = WhisperForConditionalGeneration.from_pretrained(load_path, **download_kwargs)
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)
            
            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(progress_model_name)
                task_manager.complete_download(progress_model_name)
            
            self.model.to(self.device)
            self.model_size = model_size
            
            print(f"Whisper model {model_size} loaded successfully")
            
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            progress_model_name = f"whisper-{model_size}"
            progress_manager.mark_error(progress_model_name, str(e))
            task_manager.error_download(progress_model_name, str(e))
            raise
    
    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            print("Whisper model unloaded")
    
    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.
        
        Args:
            audio_path: Path to audio file
            language: Optional language hint (e.g. en, zh, ja; any language Whisper supports)
            
        Returns:
            Transcribed text
        """
        await self.load_model_async(None)
        
        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # Load audio
            audio, sr = load_audio(audio_path, sample_rate=16000)
            
            # Process audio
            inputs = self.processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)
            
            # Set language if provided
            forced_decoder_ids = None
            if language:
                # Support all languages from frontend: en, zh, ja, ko, de, fr, ru, pt, es, it
                # Whisper supports these and many more
                forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language=language,
                    task="transcribe",
                )
            
            # Generate transcription
            with torch.no_grad():
                predicted_ids = self.model.generate(
                    inputs["input_features"],
                    forced_decoder_ids=forced_decoder_ids,
                )
            
            # Decode
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]
            
            return transcription.strip()
        
        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
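
The cache check used by the backends in this file can be tried out on its own. The sketch below is a standalone version of the same idea, assuming the usual hub cache layout (a `models--org--name` directory containing `blobs` and `snapshots/<revision>`). The names `find_cached_snapshot`, `cache_root` and `demo` are hypothetical and not part of the project; the cache root is passed in explicitly so the example runs without `huggingface_hub` installed.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_cached_snapshot(cache_root: Path, model_id: str) -> Optional[Path]:
    """Return the newest complete snapshot directory for model_id, or None."""
    repo_cache = cache_root / ("models--" + model_id.replace("/", "--"))
    snapshots_dir = repo_cache / "snapshots"
    if not snapshots_dir.is_dir():
        return None

    # An *.incomplete blob means a download was interrupted or is still running.
    blobs_dir = repo_cache / "blobs"
    if blobs_dir.is_dir() and any(blobs_dir.glob("*.incomplete")):
        return None

    # Pick the newest snapshot, judged by modification time.
    snapshots = sorted(
        (p for p in snapshots_dir.iterdir() if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not snapshots:
        return None
    latest = snapshots[0]

    # Require both a weight file and config.json before calling it usable.
    has_weights = any(latest.rglob("*.safetensors")) or any(latest.rglob("*.bin"))
    if not has_weights or not (latest / "config.json").exists():
        return None
    return latest


def demo() -> tuple:
    """Build a throwaway cache tree and run the check in three situations."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        repo = root / "models--openai--whisper-base"
        snapshot = repo / "snapshots" / "abc123"
        snapshot.mkdir(parents=True)
        (snapshot / "config.json").write_text("{}")
        (snapshot / "model.safetensors").write_bytes(b"\0")

        # Complete snapshot: the check should return its directory.
        found = find_cached_snapshot(root, "openai/whisper-base")
        complete_name = found.name if found else None

        # Simulate an interrupted download next to the finished snapshot.
        (repo / "blobs").mkdir()
        (repo / "blobs" / "deadbeef.incomplete").write_bytes(b"")
        after_partial = find_cached_snapshot(root, "openai/whisper-base")

        # A model that was never downloaded has no repo directory at all.
        missing = find_cached_snapshot(root, "openai/whisper-tiny")
        return complete_name, after_partial, missing
```

When a path comes back, handing that path to `from_pretrained` in place of the Hub model ID is what the listing above relies on to keep the load local; when `None` comes back, the code falls through to the download branch.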

6. backend/backends/mlx_backend.py

"""
MLX backend implementation for TTS and STT using mlx-audio.
"""

from typing import Optional, List, Tuple
import asyncio
import numpy as np
from pathlib import Path

from . import TTSBackend, STTBackend
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
from ..utils.hf_config import (
    get_model_download_kwargs,
    setup_huggingface_for_online,
)


class MLXTTSBackend:
    """MLX-based TTS backend using mlx-audio."""
    
    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self._current_model_size = None
    
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None
    
    def _get_model_path(self, model_size: str) -> str:
        """
        Get the MLX model path.
        
        Args:
            model_size: Model size (1.7B or 0.6B)
            
        Returns:
            HuggingFace Hub model ID for MLX
        """
        # MLX model mapping
        mlx_model_map = {
            "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
            # 0.6B not yet converted to MLX format
            "0.6B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",  # Fallback to 1.7B
        }
        
        if model_size not in mlx_model_map:
            raise ValueError(f"Unknown model size: {model_size}")
        
        hf_model_id = mlx_model_map[model_size]
        print(f"Resolved MLX model ID on HuggingFace Hub: {hf_model_id}")
        
        return hf_model_id
    
    def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
        """
        Get the local snapshot path if the model is fully cached.

        Args:
            model_size: Model size to check

        Returns:
            Path to local snapshot if fully cached, None otherwise
        """
        try:
            from huggingface_hub import constants as hf_constants
            model_id = self._get_model_path(model_size)
            repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))

            if not repo_cache.exists():
                return None

            # Check for .incomplete files - if any exist, download is still in progress
            blobs_dir = repo_cache / "blobs"
            if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
                return None

            # Check that actual model weight files exist in snapshots
            snapshots_dir = repo_cache / "snapshots"
            if not snapshots_dir.exists():
                return None

            # Get the latest snapshot (by modification time)
            snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
            if not snapshot_dirs:
                return None

            latest_snapshot = snapshot_dirs[0]

            # Check that the snapshot contains model weight files (rglob also matches symlinks)
            has_weights = (
                any(latest_snapshot.rglob("*.safetensors")) or
                any(latest_snapshot.rglob("*.bin")) or
                any(latest_snapshot.rglob("*.npz"))
            )
            if not has_weights:
                return None

            # Check for config.json
            if not (latest_snapshot / "config.json").exists():
                return None

            return latest_snapshot
        except Exception as e:
            print(f"[_get_local_snapshot_path] Error: {e}")
            return None

    def _is_model_cached(self, model_size: str) -> bool:
        """
        Check if the model is already cached locally AND fully downloaded.

        Args:
            model_size: Model size to check

        Returns:
            True if model is fully cached, False if missing or incomplete
        """
        local_path = self._get_local_snapshot_path(model_size)
        if local_path:
            print(f"[_is_model_cached] Model {model_size} is fully cached at {local_path}")
        else:
            print(f"[_is_model_cached] Model {model_size} is not cached")
        return local_path is not None
    
    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX TTS model.
        
        Args:
            model_size: Model size to load (1.7B or 0.6B)
        """
        if model_size is None:
            model_size = self.model_size
            
        # If already loaded with correct size, return
        if self.model is not None and self._current_model_size == model_size:
            return
        
        # Unload existing model if different size requested
        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()
        
        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)
    
    # Alias for compatibility
    load_model = load_model_async
    
    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        try:
            # Set up progress tracking
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"

            # Get local snapshot path if model is cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # Get model ID for HuggingFace Hub (used for downloading)
            model_id = self._get_model_path(model_size)

            # Determine the path to use for loading
            # If cached, use local snapshot path directly to avoid any network access
            # If not cached, use HuggingFace Hub ID to download
            load_path = str(local_snapshot_path) if is_cached else model_id

            if is_cached:
                print(f"[MLX TTS] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[MLX TTS] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # Set up progress callback
            # If cached: filter out non-download progress
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

            print(f"Loading MLX TTS model {model_size}...")

            # Only track download progress if model is NOT cached
            if not is_cached:
                # Start tracking download task
                task_manager.start_download(model_name)

                # Initialize progress state so SSE endpoint has initial data to send
                # This provides immediate feedback while HuggingFace fetches metadata
                progress_manager.update_progress(
                    model_name=model_name,
                    current=0,
                    total=0,  # Will be updated once actual total is known
                    filename="Connecting to HuggingFace...",
                    status="downloading",
                )

            # IMPORTANT: Patch tqdm BEFORE importing mlx_audio
            # Otherwise mlx_audio caches reference to original tqdm
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()

            # Import mlx_audio AFTER patching tqdm. The import sits inside the
            # try block so the tqdm patch is undone even if the import fails.
            try:
                from mlx_audio.tts import load

                # A local snapshot path loads without network access; a Hub
                # model ID triggers a download. The call is the same either way.
                self.model = load(load_path)
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)
            
            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(model_name)
                task_manager.complete_download(model_name)
            
            self._current_model_size = model_size
            self.model_size = model_size
            
            print(f"MLX TTS model {model_size} loaded successfully")
            
        except ImportError as e:
            print(f"Error: mlx_audio package not found. Install with: pip install mlx-audio")
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"
            progress_manager.mark_error(model_name, str(e))
            task_manager.error_download(model_name, str(e))
            raise
        except Exception as e:
            print(f"Error loading MLX TTS model: {e}")
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            model_name = f"qwen-tts-{model_size}"
            progress_manager.mark_error(model_name, str(e))
            task_manager.error_download(model_name, str(e))
            raise
    
    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            print("MLX TTS model unloaded")
    
    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.
        
        MLX backend stores voice prompt as a dict with audio path and text.
        The actual voice prompt processing happens during generation.
        
        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available
            
        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)
        
        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                # Return cached prompt (should be dict format)
                if isinstance(cached_prompt, dict):
                    # Validate that the cached audio file still exists
                    cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
                    if cached_audio_path and Path(cached_audio_path).exists():
                        return cached_prompt, True
                    else:
                        # Cached file no longer exists, invalidate cache
                        print(f"Cached audio file not found: {cached_audio_path}, regenerating prompt")
        
        # MLX voice prompt format - store audio path and text
        # The model will process this during generation
        voice_prompt_items = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        
        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)
        
        return voice_prompt_items, False
    
    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """
        Combine multiple reference samples for better quality.
        
        Args:
            audio_paths: List of audio file paths
            reference_texts: List of reference texts
            
        Returns:
            Tuple of (combined_audio, combined_text)
        """
        combined_audio = []
        
        for audio_path in audio_paths:
            audio, sr = load_audio(audio_path)
            audio = normalize_audio(audio)
            combined_audio.append(audio)
        
        # Concatenate audio
        mixed = np.concatenate(combined_audio)
        mixed = normalize_audio(mixed)
        
        # Combine texts
        combined_text = " ".join(reference_texts)
        
        return mixed, combined_text
    
    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code (en or zh) - may not be fully supported by MLX
            seed: Random seed for reproducibility
            instruct: Natural language instruction (may not be supported by MLX)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        print(f"Generating audio for text: {text}")

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # MLX generate() returns a generator yielding GenerationResult objects
            audio_chunks = []
            sample_rate = 24000
            
            # Set seed if provided (seeds both the NumPy and MLX RNGs)
            if seed is not None:
                import mlx.core as mx
                np.random.seed(seed)
                mx.random.seed(seed)
            
            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")
            
            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                print(f"Warning: Audio file not found: {ref_audio}")
                print("This may be due to a cached voice prompt referencing a deleted temp file.")
                print("Regenerating without voice prompt.")
                ref_audio = None
            
            # Check if model supports voice cloning via generate method
            # MLX API may support ref_audio parameter directly
            try:
                # Try with voice cloning parameters if supported
                if ref_audio:
                    # Check if generate accepts ref_audio parameter
                    import inspect
                    sig = inspect.signature(self.model.generate)
                    if "ref_audio" in sig.parameters:
                        # Generate with voice cloning
                        for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                    else:
                        # Fallback: generate without voice cloning
                        for result in self.model.generate(text):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                else:
                    # No voice prompt, generate normally
                    for result in self.model.generate(text):
                        audio_chunks.append(np.array(result.audio))
                        sample_rate = result.sample_rate
            except Exception as e:
                # If voice cloning fails, try without it. Discard any chunks from
                # the failed attempt first so partial audio is not kept in the output.
                print(f"Warning: Voice cloning failed, generating without voice prompt: {e}")
                audio_chunks.clear()
                for result in self.model.generate(text):
                    audio_chunks.append(np.array(result.audio))
                    sample_rate = result.sample_rate
            
            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)
            
            return audio, sample_rate

        # Run blocking inference in thread pool
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper."""
    
    def __init__(self, model_size: str = "base"):
        self.model = None
        self.model_size = model_size
    
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None
    
    def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
        """
        Get local snapshot path if Whisper model is fully cached.

        Args:
            model_size: Model size to check

        Returns:
            Path to local snapshot if fully cached, None otherwise
        """
        try:
            from huggingface_hub import constants as hf_constants
            model_id = f"openai/whisper-{model_size}"
            repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))

            if not repo_cache.exists():
                return None

            # Check for .incomplete files - if any exist, download is still in progress
            blobs_dir = repo_cache / "blobs"
            if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
                return None

            # Check that actual model weight files exist in snapshots
            snapshots_dir = repo_cache / "snapshots"
            if not snapshots_dir.exists():
                return None

            # Get the latest snapshot (by modification time)
            snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
            if not snapshot_dirs:
                return None

            latest_snapshot = snapshot_dirs[0]

            # Check for model weight entries (rglob matches by name, so symlinks
            # into blobs/ count too; the link targets are not verified here)
            has_weights = (
                any(latest_snapshot.rglob("*.safetensors")) or
                any(latest_snapshot.rglob("*.bin")) or
                any(latest_snapshot.rglob("*.npz"))
            )
            if not has_weights:
                return None

            # Check for config.json
            if not (latest_snapshot / "config.json").exists():
                return None

            return latest_snapshot
        except Exception as e:
            print(f"[_get_local_snapshot_path] Error: {e}")
            return None

    def _is_model_cached(self, model_size: str) -> bool:
        """
        Check if the Whisper model is already cached locally AND fully downloaded.

        Args:
            model_size: Model size to check

        Returns:
            True if model is fully cached, False if missing or incomplete
        """
        local_path = self._get_local_snapshot_path(model_size)
        if local_path:
            print(f"[_is_model_cached] Whisper model {model_size} is fully cached at {local_path}")
        else:
            print(f"[_is_model_cached] Whisper model {model_size} is not cached")
        return local_path is not None
    
    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.
        
        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size
        
        if self.model is not None and self.model_size == model_size:
            return
        
        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)
    
    # Alias for compatibility
    load_model = load_model_async
    
    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        try:
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            progress_model_name = f"whisper-{model_size}"

            # Get local snapshot path if model is cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # Get model ID for HuggingFace Hub (used for downloading)
            model_id = f"openai/whisper-{model_size}"

            # Determine the path to use for loading
            # If cached, use local snapshot path directly to avoid any network access
            # If not cached, use HuggingFace Hub ID to download
            load_path = str(local_snapshot_path) if is_cached else model_id

            if is_cached:
                print(f"[MLX Whisper] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[MLX Whisper] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # Set up progress callback and tracker
            # If cached: filter out non-download progress
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(progress_model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

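            # Patch tqdm BEFORE importing mlx_audio
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()

            # Import mlx_audio; undo the tqdm patch if the import fails,
            # since the try/finally below is never reached in that case
            try:
                from mlx_audio.stt import load
            except Exception:
                tracker_context.__exit__(None, None, None)
                raise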

            print(f"Loading MLX Whisper model {model_size}...")

            # Only track download progress if model is NOT cached
            if not is_cached:
                # Start tracking download task
                task_manager.start_download(progress_model_name)

                # Initialize progress state so SSE endpoint has initial data to send
                progress_manager.update_progress(
                    model_name=progress_model_name,
                    current=0,
                    total=0,
                    filename="Connecting to HuggingFace...",
                    status="downloading",
                )

            # Load the model
            try:
                # load_path is either the local snapshot directory (no network
                # access) or the HuggingFace Hub ID (triggers a download);
                # the call is the same in both cases
                self.model = load(load_path)
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)
            
            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(progress_model_name)
                task_manager.complete_download(progress_model_name)
            
            self.model_size = model_size
            
            print(f"MLX Whisper model {model_size} loaded successfully")
            
        except Exception as e:
            # ImportError gets an install hint; every failure is reported the same way
            if isinstance(e, ImportError):
                print("Error: mlx_audio package not found. Install with: pip install mlx-audio")
            else:
                print(f"Error loading MLX Whisper model: {e}")
            progress_manager = get_progress_manager()
            task_manager = get_task_manager()
            progress_model_name = f"whisper-{model_size}"
            progress_manager.mark_error(progress_model_name, str(e))
            task_manager.error_download(progress_model_name, str(e))
            raise
    
    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            print("MLX Whisper model unloaded")
    
    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint (en or zh)

        Returns:
            Transcribed text
        """
        await self.load_model_async(None)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # MLX Whisper transcription using generate method
            # The generate method accepts audio path directly
            decode_options = {}
            if language:
                decode_options["language"] = language

            result = self.model.generate(str(audio_path), **decode_options)

            # Extract text from result
            if isinstance(result, str):
                return result.strip()
            elif isinstance(result, dict):
                return result.get("text", "").strip()
            elif hasattr(result, "text"):
                return result.text.strip()
            else:
                return str(result).strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
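
The cache check in `_get_local_snapshot_path` relies only on the on-disk layout of the HuggingFace cache (`models--{org}--{name}/snapshots/{revision}/`), so it can be tried out without downloading a model. The following is a minimal sketch using only the standard library. The names `find_complete_snapshot` and `build_fake_cache` and the revision `abc123` are made up for illustration and are not part of Voicebox or huggingface_hub.

```python
import tempfile
from pathlib import Path
from typing import Optional

WEIGHT_PATTERNS = ("*.safetensors", "*.bin", "*.npz")


def find_complete_snapshot(cache_root: Path, model_id: str) -> Optional[Path]:
    """Return the newest snapshot directory if it looks complete, else None."""
    repo_cache = cache_root / ("models--" + model_id.replace("/", "--"))
    snapshots_dir = repo_cache / "snapshots"
    if not snapshots_dir.is_dir():
        return None

    # A leftover *.incomplete blob means a download is unfinished
    blobs_dir = repo_cache / "blobs"
    if blobs_dir.is_dir() and any(blobs_dir.glob("*.incomplete")):
        return None

    snapshots = sorted(
        (p for p in snapshots_dir.iterdir() if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not snapshots:
        return None

    latest = snapshots[0]
    has_weights = any(any(latest.rglob(pattern)) for pattern in WEIGHT_PATTERNS)
    if not has_weights or not (latest / "config.json").exists():
        return None
    return latest


def build_fake_cache(root: Path, model_id: str, with_weights: bool) -> Path:
    """Create a throwaway cache entry for the demo (hypothetical helper)."""
    repo = root / ("models--" + model_id.replace("/", "--"))
    snapshot = repo / "snapshots" / "abc123"
    snapshot.mkdir(parents=True)
    (snapshot / "config.json").write_text("{}")
    if with_weights:
        (snapshot / "model.safetensors").write_bytes(b"")
    return snapshot


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    expected = build_fake_cache(root, "openai/whisper-base", with_weights=True)
    build_fake_cache(root, "openai/whisper-tiny", with_weights=False)

    base_found = find_complete_snapshot(root, "openai/whisper-base") == expected
    tiny_found = find_complete_snapshot(root, "openai/whisper-tiny")
    missing_found = find_complete_snapshot(root, "openai/whisper-large")
```

Here `base_found` is `True`, while the entry without weights and the repository that was never downloaded both yield `None`. Passing the returned directory to the loader instead of the Hub ID is what lets the backend skip the network when the model is already on disk.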

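III. Notes

1. The code in mlx_backend.py has not been tested, because I do not have a Mac.
2. When developing and releasing an application, network access that the user cannot control makes users uneasy. It needs to be considered from the user's point of view and handled with care.
3. The code adds support for the https://hf-mirror.com mirror.
4. The code disables SSL verification. This was a choice made to work around SSL verification failures during model download, but it is a security risk, and anyone reusing the code above should take note.
5. The Voicebox project provides prebuilt executables for several operating systems. I downloaded the Windows build, Voicebox_0.1.13_x64-setup.exe. After installing and running it, SSL certificate verification errors meant the downloads of the two models Qwen/Qwen3-TTS-12Hz-1.7B-Base and Qwen/Qwen3-TTS-12Hz-0.6B-Base never completed, so the prebuilt program could not be used normally. This article is the result of having no other option.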
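
As a companion to notes 3 and 4, here is a minimal sketch of choosing between offline mode and mirror access through environment variables, with a CA bundle as a possible alternative to disabling SSL verification. `build_hf_env`, `apply_hf_env`, and the bundle path are hypothetical. `HF_HUB_OFFLINE`, `TRANSFORMERS_OFFLINE`, and `HF_ENDPOINT` are the variables used earlier in this article. Whether `REQUESTS_CA_BUNDLE` and `SSL_CERT_FILE` are honored depends on the HTTP client of the installed huggingface_hub version, so treat that part as an assumption to verify.

```python
import os
from typing import Dict, Optional


def build_hf_env(
    offline: bool,
    mirror: Optional[str] = None,
    ca_bundle: Optional[str] = None,
) -> Dict[str, str]:
    """Build the environment variables for one run (hypothetical helper)."""
    env: Dict[str, str] = {}
    if offline:
        # Model is already cached: forbid network access entirely
        env["HF_HUB_OFFLINE"] = "1"
        env["TRANSFORMERS_OFFLINE"] = "1"
        return env
    if mirror:
        env["HF_ENDPOINT"] = mirror
    if ca_bundle:
        # Assumption: the HTTP client reads one of these variables
        env["REQUESTS_CA_BUNDLE"] = ca_bundle
        env["SSL_CERT_FILE"] = ca_bundle
    return env


def apply_hf_env(env: Dict[str, str]) -> None:
    """Apply the settings to the current process."""
    os.environ.update(env)


offline_env = build_hf_env(offline=True, mirror="https://hf-mirror.com")
online_env = build_hf_env(
    offline=False,
    mirror="https://hf-mirror.com",
    ca_bundle="/path/to/ca-bundle.pem",  # hypothetical path
)
```

`apply_hf_env` should run before `huggingface_hub` or `transformers` is imported, because these libraries read such settings when they are first imported and later changes to `os.environ` may have no effect. Pointing the client at a trusted CA bundle keeps certificate checks in place, which is usually preferable to turning verification off when the failure comes from a corporate proxy or a missing root certificate.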
