Youtu-2B Model Serving: A Complete Flask + WebUI Guide

1. Introduction: Why Youtu-2B?

If you're looking for an AI chat assistant that is both lightweight and capable, the Youtu-2B model is well worth your time. Picture a "small" model of only 2 billion parameters that nevertheless performs impressively on math problems, coding tasks, and everyday conversation, while demanding very little hardware: an ordinary consumer GPU, or even a CPU, runs it smoothly.

That is the appeal of Youtu-LLM-2B from Tencent's YouTu Lab. But the model itself is only the beginning; turning it into an always-available service is where the real value lies. This guide walks through the full journey from model to service: building a robust backend API with Flask and pairing it with a clean, good-looking WebUI, so you and your team can talk to the assistant with ease.

Whether you want to stand up an internal Q&A system, need a coding helper, or simply want to experience a lightweight LLM, this guide offers a clear path. We skip the theoretical derivations and focus on the "how": from environment setup to code, from UI design to deployment tuning, every step comes with concrete instructions and runnable examples.

2. Environment Setup and Quick Deployment

2.1 System Requirements and Dependencies

Before we start, let's see what you need. The good news: Youtu-2B's hardware requirements are modest.

Base environment:

  • OS: Ubuntu 20.04/22.04 or CentOS 7/8 (Windows users should use WSL2)
  • Python: 3.8 or 3.9
  • Memory: at least 8 GB RAM
  • Disk: 10 GB free (for model files and packages)
  • GPU: optional but recommended (inference is faster with one)
    • an NVIDIA GPU with 4 GB+ VRAM
    • or CPU-only (slower, but fully workable)

Install all dependencies in one go:

Open a terminal and run the following to set up the environment:

# Create the project directory
mkdir youtu-2b-service && cd youtu-2b-service

# Create a Python virtual environment (recommended)
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or venv\Scripts\activate  # Windows

# Install PyTorch (pick the index URL matching your setup)
# CPU-only (no GPU):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# With a GPU (CUDA 11.8 as an example):
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Model inference and web framework
# (quote the version specifiers so the shell doesn't treat ">" as a redirect)
pip install "transformers>=4.30.0"
pip install "flask>=2.3.0"
pip install "flask-cors>=4.0.0"

# WebUI dependency
pip install "gradio>=3.40.0"

# Other utilities
pip install "requests>=2.31.0"
pip install "sentencepiece>=0.1.99"  # required by the tokenizer
pip install "psutil>=5.9.0"          # used by the health-check endpoint later

Installation can take a few minutes depending on your network. If downloads are slow, consider a domestic mirror:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple [package]
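For reproducibility, you can also pin everything in the requirements.txt that the project layout in section 3.1 expects. A minimal sketch (versions mirror the installs above; the bare torch line assumes you pick the right wheel for your CUDA setup yourself):

# requirements.txt
torch
transformers>=4.30.0
flask>=2.3.0
flask-cors>=4.0.0
gradio>=3.40.0
requests>=2.31.0
sentencepiece>=0.1.99
psutil>=5.9.0

With that file in place, pip install -r requirements.txt reproduces the environment.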

2.2 Downloading and Verifying the Model

With the environment ready, we need the Youtu-2B model itself. There are two ways to get it:

Option 1: download directly from Hugging Face (recommended)

# download_model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Model identifier and local save path
model_name = "Tencent-YouTu-Research/Youtu-LLM-2B"
save_path = "./models/youtu-2b"

# Create the target directory
os.makedirs(save_path, exist_ok=True)

print("Downloading the Youtu-2B model...")
print("This may take a while; the model is roughly 4 GB")

# Download the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto"  # pick the dtype automatically
)

# Save locally
tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)

print(f"Model saved to: {save_path}")
print("Download complete!")

Option 2: use pre-downloaded model files

If you already have the model files locally, copy them into place:

# Assuming the files live at a local path
mkdir -p ./models/youtu-2b
cp -r /path/to/your/model/files/* ./models/youtu-2b/
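Alternatively, if you have the huggingface_hub package installed, its CLI should be able to fetch the whole repository in one line (shown as a convenience; this assumes a reasonably recent huggingface_hub version):

# One-line download via the Hugging Face CLI
pip install -U huggingface_hub
huggingface-cli download Tencent-YouTu-Research/Youtu-LLM-2B --local-dir ./models/youtu-2b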

Verify the model works:

Once the download finishes, run a quick test script to confirm the model responds:

# test_model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "./models/youtu-2b"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"  # place the model on the available device(s)
)

# Run a test inference
test_prompt = "Hello, please introduce yourself."
inputs = tokenizer(test_prompt, return_tensors="pt")

# Move the inputs to the model's device
if torch.cuda.is_available():
    inputs = {k: v.cuda() for k, v in inputs.items()}

print("Running inference...")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True
    )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Test reply: {response}")
print("Model test passed!")

If the model replies normally, your environment is good to go. Next, we build the service.

3. Wrapping the Backend in Flask

3.1 Project Layout

Before writing code, let's lay out the project directory. A good structure keeps the code clear and eases later maintenance.

youtu-2b-service/
├── app.py                    # Flask entry point
├── config.py                # configuration
├── requirements.txt         # dependency list
├── models/
│   └── youtu-2b/           # model files
├── api/
│   ├── __init__.py
│   ├── chat.py             # chat API
│   └── health.py           # health-check API
├── core/
│   ├── __init__.py
│   ├── model_loader.py     # model loader
│   └── inference.py        # inference engine
├── static/                 # static assets
│   ├── css/
│   └── js/
├── templates/              # HTML templates
│   └── index.html
└── utils/
    ├── __init__.py
    └── logger.py           # logging helpers

Create the structure:

mkdir -p api core static/css static/js templates utils
touch app.py config.py requirements.txt
touch api/__init__.py api/chat.py api/health.py
touch core/__init__.py core/model_loader.py core/inference.py
touch utils/__init__.py utils/logger.py
touch templates/index.html

3.2 The Core Model Loader

We start with the model loading module, the foundation of the whole service:

# core/model_loader.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import Tuple

logger = logging.getLogger(__name__)

class ModelLoader:
    """Loads and manages the model."""
    
    def __init__(self, model_path: str = "./models/youtu-2b"):
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.device = None
        
    def load_model(self, use_gpu: bool = True) -> Tuple[bool, str]:
        """
        Load the model and tokenizer.
        
        Args:
            use_gpu: whether to use the GPU
            
        Returns:
            (success flag, message)
        """
        try:
            logger.info(f"Loading model from: {self.model_path}")
            
            # Pick the device
            if use_gpu and torch.cuda.is_available():
                self.device = torch.device("cuda")
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
            else:
                self.device = torch.device("cpu")
                logger.info("Using CPU")
            
            # Load the tokenizer
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            
            # Set a pad token if the tokenizer lacks one
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Load the model weights
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto" if self.device.type == "cuda" else None
            )
            
            # On CPU we have to move the model ourselves
            if self.device.type == "cpu":
                self.model = self.model.to(self.device)
            
            # Switch to evaluation mode
            self.model.eval()
            
            logger.info("Model loaded!")
            return True, "Model loaded successfully"
            
        except Exception as e:
            error_msg = f"Model load failed: {str(e)}"
            logger.error(error_msg)
            return False, error_msg
    
    def get_model_info(self) -> dict:
        """Return metadata about the loaded model."""
        if self.model is None:
            return {"status": "not_loaded"}
        
        info = {
            "status": "loaded",
            "model_path": self.model_path,
            "device": str(self.device),
            "model_type": type(self.model).__name__,
            "parameters": sum(p.numel() for p in self.model.parameters()),
        }
        
        if self.device.type == "cuda":
            info["gpu_memory"] = f"{torch.cuda.memory_allocated() / 1024**3:.2f} GB"
        
        return info
    
    def unload_model(self):
        """Unload the model and free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model unloaded")
3.3 The Inference Engine

Next comes the inference engine, the core that handles user requests:

# core/inference.py
import torch
import time
from typing import Dict, List
import logging
from .model_loader import ModelLoader

logger = logging.getLogger(__name__)

class InferenceEngine:
    """Runs generation requests against the loaded model."""
    
    def __init__(self, model_loader: ModelLoader):
        self.model_loader = model_loader
        self.tokenizer = model_loader.tokenizer
        self.model = model_loader.model
        self.device = model_loader.device
        
    def generate_response(
        self,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        **kwargs
    ) -> Dict:
        """
        Generate a reply.
        
        Args:
            prompt: input text
            max_length: generation budget (also used to truncate the prompt)
            temperature: sampling temperature (controls randomness)
            top_p: nucleus-sampling parameter
            repetition_penalty: penalty for repeated tokens
            
        Returns:
            dict containing the reply and metadata
        """
        if self.model is None or self.tokenizer is None:
            return {
                "response": "Model not loaded, please retry later",
                "error": "model_not_loaded"
            }
        
        try:
            start_time = time.time()
            
            # Encode the input (truncating overly long prompts)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            )
            
            # Move tensors to the model's device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # Generation config, with parameters clamped to safe ranges
            generate_kwargs = {
                "max_new_tokens": min(1024, max_length),
                "temperature": max(0.1, min(2.0, temperature)),
                "top_p": max(0.1, min(1.0, top_p)),
                "repetition_penalty": max(1.0, min(2.0, repetition_penalty)),
                "do_sample": temperature > 0.1,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
            }
            
            # Merge any extra caller-provided parameters
            generate_kwargs.update(kwargs)
            
            # Generate the reply
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **generate_kwargs
                )
            
            # Decode only the newly generated tokens
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            
            # Timing
            inference_time = time.time() - start_time
            
            # Token statistics
            input_tokens = inputs["input_ids"].shape[1]
            output_tokens = outputs.shape[1] - input_tokens
            tokens_per_second = output_tokens / inference_time if inference_time > 0 else 0
            
            logger.info(f"Inference done: {input_tokens} -> {output_tokens} tokens, "
                       f"time: {inference_time:.2f}s, "
                       f"speed: {tokens_per_second:.1f} tokens/s")
            
            return {
                "response": response.strip(),
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
                "inference_time": round(inference_time, 3),
                "tokens_per_second": round(tokens_per_second, 1),
                "model": "Youtu-LLM-2B",
                "success": True
            }
            
        except Exception as e:
            error_msg = f"Inference failed: {str(e)}"
            logger.error(error_msg)
            return {
                "response": "An error occurred while generating the reply",
                "error": error_msg,
                "success": False
            }
    
    def batch_generate(
        self,
        prompts: List[str],
        **kwargs
    ) -> List[Dict]:
        """Generate replies for several prompts, one after another."""
        results = []
        for prompt in prompts:
            result = self.generate_response(prompt, **kwargs)
            results.append(result)
        return results
3.4 The Flask API

Now let's create the Flask API endpoints:

# api/chat.py
from flask import Blueprint, request, jsonify
import logging

from core.inference import InferenceEngine

# Create the blueprint
chat_bp = Blueprint('chat', __name__)
logger = logging.getLogger(__name__)

# Global inference-engine instance
_inference_engine = None

def init_inference_engine(model_loader):
    """Initialize the inference engine once."""
    global _inference_engine
    if _inference_engine is None:
        _inference_engine = InferenceEngine(model_loader)
    return _inference_engine

@chat_bp.route('/chat', methods=['POST'])
def chat():
    """
    Chat endpoint.
    
    Request body:
    {
        "prompt": "your question",
        "max_length": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "stream": false
    }
    """
    try:
        # Parse the request payload
        if request.is_json:
            data = request.get_json()
        else:
            data = request.form.to_dict()
        
        # Required parameter
        prompt = data.get('prompt', '').strip()
        if not prompt:
            return jsonify({
                "error": "prompt must not be empty",
                "success": False
            }), 400
        
        # Optional parameters
        max_length = int(data.get('max_length', 512))
        temperature = float(data.get('temperature', 0.7))
        top_p = float(data.get('top_p', 0.9))
        stream = data.get('stream', False)
        
        # Validate parameters
        if max_length < 1 or max_length > 4096:
            return jsonify({
                "error": "max_length must be between 1 and 4096",
                "success": False
            }), 400
        
        if temperature < 0.1 or temperature > 2.0:
            return jsonify({
                "error": "temperature must be between 0.1 and 2.0",
                "success": False
            }), 400
        
        if _inference_engine is None:
            return jsonify({
                "error": "Inference engine not initialized",
                "success": False
            }), 503
        
        # The `stream` flag is accepted but, for simplicity, we always
        # return the complete result here; a true token-by-token streaming
        # variant is sketched after this file
        result = _inference_engine.generate_response(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )
        return jsonify(result)
            
    except ValueError as e:
        logger.error(f"Bad parameter: {str(e)}")
        return jsonify({
            "error": f"Malformed parameter: {str(e)}",
            "success": False
        }), 400
    except Exception as e:
        logger.error(f"Error handling request: {str(e)}")
        return jsonify({
            "error": f"Internal server error: {str(e)}",
            "success": False
        }), 500

@chat_bp.route('/batch_chat', methods=['POST'])
def batch_chat():
    """Batched chat endpoint."""
    try:
        data = request.get_json()
        prompts = data.get('prompts', [])
        
        if not isinstance(prompts, list) or len(prompts) == 0:
            return jsonify({
                "error": "prompts must be a non-empty array",
                "success": False
            }), 400
        
        if len(prompts) > 10:  # cap the batch size
            return jsonify({
                "error": "batch requests support at most 10 prompts",
                "success": False
            }), 400
        
        # Shared parameters
        max_length = int(data.get('max_length', 512))
        temperature = float(data.get('temperature', 0.7))
        
        if _inference_engine is None:
            return jsonify({
                "error": "Inference engine not initialized",
                "success": False
            }), 503
        
        results = _inference_engine.batch_generate(
            prompts=prompts,
            max_length=max_length,
            temperature=temperature
        )
        
        return jsonify({
            "results": results,
            "count": len(results),
            "success": True
        })
        
    except Exception as e:
        logger.error(f"Batch processing failed: {str(e)}")
        return jsonify({
            "error": f"Processing failed: {str(e)}",
            "success": False
        }), 500
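For reference, here is what a real streaming variant could look like. This is a sketch, not part of the original code: it assumes the model works with transformers' TextIteratorStreamer, runs generate() in a background thread, and pushes decoded chunks out through a server-sent-events response. The route name /chat_stream and the SSE framing are illustrative choices.

# Hypothetical streaming endpoint, added to api/chat.py
import json
from threading import Thread
from flask import Response, stream_with_context
from transformers import TextIteratorStreamer

@chat_bp.route('/chat_stream', methods=['POST'])
def chat_stream():
    data = request.get_json()
    prompt = data.get('prompt', '').strip()
    engine = _inference_engine
    
    inputs = engine.tokenizer(prompt, return_tensors="pt").to(engine.device)
    streamer = TextIteratorStreamer(
        engine.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    
    # generate() blocks until finished, so run it in a background thread
    # and read decoded text chunks from the streamer as they arrive
    thread = Thread(
        target=engine.model.generate,
        kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer)
    )
    thread.start()
    
    def event_stream():
        for chunk in streamer:
            yield f"data: {json.dumps({'delta': chunk})}\n\n"
        yield "data: [DONE]\n\n"
    
    return Response(stream_with_context(event_stream()), mimetype='text/event-stream')

With the chat endpoints done, the health-check blueprint comes next: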
# api/health.py
from flask import Blueprint, jsonify
import psutil
import torch
import time

health_bp = Blueprint('health', __name__)

@health_bp.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint."""
    try:
        # Basic system metrics
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory = psutil.virtual_memory()
        
        health_data = {
            "status": "healthy",
            "service": "youtu-2b-api",
            "timestamp": time.time(),  # current Unix timestamp
            "system": {
                "cpu_percent": cpu_percent,
                "memory_percent": memory.percent,
                "memory_available_gb": round(memory.available / (1024**3), 2),
                "memory_total_gb": round(memory.total / (1024**3), 2)
            }
        }
        
        # GPU info, if available
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated()
            gpu_memory_reserved = torch.cuda.memory_reserved()
            health_data["gpu"] = {
                "available": True,
                "device_count": torch.cuda.device_count(),
                "current_device": torch.cuda.current_device(),
                "device_name": torch.cuda.get_device_name(0),
                "memory_allocated_gb": round(gpu_memory / (1024**3), 2),
                "memory_reserved_gb": round(gpu_memory_reserved / (1024**3), 2)
            }
        else:
            health_data["gpu"] = {"available": False}
        
        return jsonify(health_data)
        
    except Exception as e:
        return jsonify({
            "status": "unhealthy",
            "error": str(e)
        }), 500

@health_bp.route('/info', methods=['GET'])
def model_info():
    """Model-information endpoint."""
    # Ideally this would read live data from the global model loader;
    # for simplicity we return static metadata
    info = {
        "model": "Youtu-LLM-2B",
        "provider": "Tencent YouTu Research",
        "parameters": "2B",
        "context_length": 4096,
        "languages": ["Chinese", "English"],
        "capabilities": ["chat", "code generation", "math reasoning", "copywriting"]
    }
    return jsonify(info)

3.5 The Main Flask Application

Finally, we tie all the pieces together in the main Flask app:

# app.py
from flask import Flask, render_template, jsonify
from flask_cors import CORS
import logging
from datetime import datetime
import os

from core.model_loader import ModelLoader
from api.chat import chat_bp, init_inference_engine
from api.health import health_bp

# Create required directories before wiring up the file logger
os.makedirs("logs", exist_ok=True)
os.makedirs("static/css", exist_ok=True)
os.makedirs("static/js", exist_ok=True)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f"logs/app_{datetime.now().strftime('%Y%m%d')}.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Create the Flask app
app = Flask(__name__)
CORS(app)  # allow cross-origin requests

# Configuration
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB request cap
app.json.ensure_ascii = False  # emit non-ASCII (e.g. Chinese) unescaped; Flask 2.3+ API

# Global model loader and inference engine
model_loader = None
inference_engine = None

def initialize_services():
    """Initialize all services."""
    global model_loader, inference_engine
    
    logger.info("Initializing services...")
    
    # Set up the model loader
    model_loader = ModelLoader(model_path="./models/youtu-2b")
    success, message = model_loader.load_model(use_gpu=True)
    
    if not success:
        logger.error(f"Model load failed: {message}")
        # Retry logic or a degraded mode could be added here
        raise RuntimeError(f"Model load failed: {message}")
    
    # Set up the inference engine
    inference_engine = init_inference_engine(model_loader)
    
    logger.info("Service initialization complete!")
    return True

# Flask 2.3 removed the before_first_request hook, so initialize lazily on
# the first incoming request instead; this also covers WSGI servers such as
# Gunicorn, where the __main__ block below never runs
@app.before_request
def ensure_initialized():
    """Load the model once, before the first request is handled."""
    if model_loader is None:
        try:
            initialize_services()
        except Exception as e:
            logger.error(f"Initialization failed: {e}")
            # Leave things unloaded; the health endpoint reflects this
            # and the next request will retry
# Register blueprints
app.register_blueprint(chat_bp, url_prefix='/api')
app.register_blueprint(health_bp, url_prefix='/api')

@app.route('/')
def index():
    """Home page."""
    return render_template('index.html')

@app.route('/api/status')
def status():
    """Service status."""
    if model_loader is None or inference_engine is None:
        return jsonify({
            "status": "initializing",
            "message": "Service is initializing"
        })
    
    model_info = model_loader.get_model_info()
    return jsonify({
        "status": "running",
        "model": model_info,
        "service": "Youtu-2B API Service",
        "version": "1.0.0",
        "endpoints": {
            "chat": "/api/chat",
            "batch_chat": "/api/batch_chat",
            "health": "/api/health",
            "info": "/api/info"
        }
    })

@app.errorhandler(404)
def not_found(error):
    return jsonify({"error": "Resource not found", "success": False}), 404

@app.errorhandler(500)
def internal_error(error):
    logger.error(f"Server error: {error}")
    return jsonify({"error": "Internal server error", "success": False}), 500

if __name__ == '__main__':
    # Initialize eagerly when run directly
    try:
        initialize_services()
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        # Start the server anyway (some endpoints will be unavailable)
    
    # Start the Flask app
    app.run(
        host='0.0.0.0',
        port=8080,
        debug=False,  # keep False in production
        threaded=True
    )
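With the server running (python app.py), a few curl calls verify the whole stack end to end:

# Service status and health
curl http://localhost:8080/api/status
curl http://localhost:8080/api/health

# A chat request
curl -X POST http://localhost:8080/api/chat \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Hello, introduce yourself.", "max_length": 256, "temperature": 0.7}'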

4. WebUI Design and Implementation

4.1 The HTML Interface

Now let's build a clean, practical web interface:

<!-- templates/index.html -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Youtu-2B Chat Service</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }
        
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            padding: 20px;
        }
        
        .container {
            max-width: 1200px;
            margin: 0 auto;
            background: white;
            border-radius: 20px;
            box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
            overflow: hidden;
        }
        
        .header {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 30px;
            text-align: center;
        }
        
        .header h1 {
            font-size: 2.5rem;
            margin-bottom: 10px;
        }
        
        .header p {
            font-size: 1.1rem;
            opacity: 0.9;
            max-width: 600px;
            margin: 0 auto;
            line-height: 1.6;
        }
        
        .main-content {
            display: flex;
            min-height: 600px;
        }
        
        .chat-container {
            flex: 3;
            display: flex;
            flex-direction: column;
            border-right: 1px solid #e5e7eb;
        }
        
        .chat-history {
            flex: 1;
            overflow-y: auto;
            padding: 20px;
            background: #f9fafb;
        }
        
        .message {
            margin-bottom: 20px;
            display: flex;
            gap: 12px;
        }
        
        .message.user {
            flex-direction: row-reverse;
        }
        
        .avatar {
            width: 40px;
            height: 40px;
            border-radius: 50%;
            display: flex;
            align-items: center;
            justify-content: center;
            font-weight: bold;
            flex-shrink: 0;
        }
        
        .user .avatar {
            background: #667eea;
            color: white;
        }
        
        .assistant .avatar {
            background: #10b981;
            color: white;
        }
        
        .message-content {
            max-width: 70%;
            padding: 12px 16px;
            border-radius: 18px;
            line-height: 1.5;
        }
        
        .user .message-content {
            background: #667eea;
            color: white;
            border-bottom-right-radius: 4px;
        }
        
        .assistant .message-content {
            background: #f3f4f6;
            color: #1f2937;
            border-bottom-left-radius: 4px;
        }
        
        .input-area {
            padding: 20px;
            border-top: 1px solid #e5e7eb;
            background: white;
        }
        
        .input-wrapper {
            display: flex;
            gap: 10px;
        }
        
        textarea {
            flex: 1;
            padding: 12px 16px;
            border: 2px solid #e5e7eb;
            border-radius: 12px;
            font-size: 16px;
            resize: none;
            font-family: inherit;
            transition: border-color 0.3s;
        }
        
        textarea:focus {
            outline: none;
            border-color: #667eea;
        }
        
        .send-btn {
            padding: 0 24px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            border-radius: 12px;
            font-size: 16px;
            font-weight: 600;
            cursor: pointer;
            transition: transform 0.2s, opacity 0.2s;
        }
        
        .send-btn:hover {
            transform: translateY(-2px);
            opacity: 0.9;
        }
        
        .send-btn:disabled {
            opacity: 0.5;
            cursor: not-allowed;
            transform: none;
        }
        
        .controls {
            display: flex;
            gap: 10px;
            margin-top: 12px;
            flex-wrap: wrap;
        }
        
        .control-group {
            display: flex;
            align-items: center;
            gap: 8px;
        }
        
        .control-group label {
            font-size: 14px;
            color: #6b7280;
        }
        
        .control-group input[type="range"] {
            width: 120px;
        }
        
        .control-group input[type="number"] {
            width: 60px;
            padding: 4px 8px;
            border: 1px solid #d1d5db;
            border-radius: 6px;
            text-align: center;
        }
        
        .sidebar {
            flex: 1;
            padding: 20px;
            background: #f9fafb;
        }
        
        .info-card {
            background: white;
            border-radius: 12px;
            padding: 20px;
            margin-bottom: 20px;
            box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
        }
        
        .info-card h3 {
            color: #1f2937;
            margin-bottom: 12px;
            font-size: 1.1rem;
        }
        
        .info-item {
            display: flex;
            justify-content: space-between;
            padding: 8px 0;
            border-bottom: 1px solid #f3f4f6;
        }
        
        .info-item:last-child {
            border-bottom: none;
        }
        
        .info-label {
            color: #6b7280;
            font-size: 14px;
        }
        
        .info-value {
            color: #1f2937;
            font-weight: 500;
        }
        
        .status-indicator {
            display: inline-flex;
            align-items: center;
            gap: 6px;
            padding: 4px 12px;
            border-radius: 20px;
            font-size: 14px;
            font-weight: 500;
        }
        
        .status-online {
            background: #d1fae5;
            color: #065f46;
        }
        
        .status-offline {
            background: #fee2e2;
            color: #991b1b;
        }
        
        .typing-indicator {
            display: none;
            padding: 12px;
            text-align: center;
            color: #6b7280;
        }
        
        .typing-dots {
            display: inline-flex;
            gap: 4px;
        }
        
        .typing-dots span {
            width: 8px;
            height: 8px;
            background: #9ca3af;
            border-radius: 50%;
            animation: typing 1.4s infinite;
        }
        
        .typing-dots span:nth-child(2) {
            animation-delay: 0.2s;
        }
        
        .typing-dots span:nth-child(3) {
            animation-delay: 0.4s;
        }
        
        @keyframes typing {
            0%, 60%, 100% {
                transform: translateY(0);
                opacity: 0.6;
            }
            30% {
                transform: translateY(-6px);
                opacity: 1;
            }
        }
        
        .examples {
            display: grid;
            gap: 10px;
            margin-top: 15px;
        }
        
        .example-btn {
            padding: 10px 16px;
            background: white;
            border: 1px solid #e5e7eb;
            border-radius: 8px;
            text-align: left;
            font-size: 14px;
            color: #4b5563;
            cursor: pointer;
            transition: all 0.2s;
        }
        
        .example-btn:hover {
            background: #f3f4f6;
            border-color: #d1d5db;
            transform: translateX(4px);
        }
        
        @media (max-width: 768px) {
            .main-content {
                flex-direction: column;
            }
            
            .chat-container {
                border-right: none;
                border-bottom: 1px solid #e5e7eb;
            }
            
            .message-content {
                max-width: 85%;
            }
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1><i class="fas fa-robot"></i> Youtu-2B 智能对话服务</h1>
            <p>基于腾讯优图实验室的轻量化大语言模型,提供高效、智能的对话体验</p>
        </div>
        
        <div class="main-content">
            <div class="chat-container">
                <div class="chat-history" id="chatHistory">
                    <!-- 聊天记录将在这里显示 -->
                    <div class="message assistant">
                        <div class="avatar">AI</div>
                        <div class="message-content">
                            你好!我是Youtu-2B智能助手,基于腾讯优图实验室的轻量化大语言模型。我可以帮助你解答问题、编写代码、进行逻辑推理等。有什么可以帮你的吗?
                        </div>
                    </div>
                </div>
                
                <div class="typing-indicator" id="typingIndicator">
                    <div class="typing-dots">
                        <span></span>
                        <span></span>
                        <span></span>
                    </div>
                    正在思考...
                </div>
                
                <div class="input-area">
                    <div class="input-wrapper">
                        <textarea 
                            id="messageInput" 
                            placeholder="输入你的问题...(例如:帮我写一个Python快速排序算法)" 
                            rows="3"
                        ></textarea>
                        <button class="send-btn" id="sendBtn">
                            <i class="fas fa-paper-plane"></i> 发送
                        </button>
                    </div>
                    
                    <div class="controls">
                        <div class="control-group">
                            <label for="temperature">温度:</label>
                            <input type="range" id="temperature" min="0.1" max="2.0" step="0.1" value="0.7">
                            <input type="number" id="temperatureValue" min="0.1" max="2.0" step="0.1" value="0.7">
                        </div>
                        
                        <div class="control-group">
                            <label for="maxLength">最大长度:</label>
                            <input type="range" id="maxLength" min="100" max="2048" step="100" value="512">
                            <input type="number" id="maxLengthValue" min="100" max="2048" step="100" value="512">
                        </div>
                        
                        <button class="send-btn" id="clearBtn" style="background: #6b7280; padding: 8px 16px;">
                            <i class="fas fa-trash"></i> 清空对话
                        </button>
                    </div>
                </div>
            </div>
            
            <div class="sidebar">
                <div class="info-card">
                    <h3><i class="fas fa-info-circle"></i> 服务状态</h3>
                    <div class="info-item">
                        <span class="info-label">模型状态:</span>
                        <span class="info-value">
                            <span class="status-indicator status-online" id="modelStatus">
                                <i class="fas fa-circle"></i> 在线
                            </span>
                        </span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">模型名称:</span>
                        <span class="info-value">Youtu-LLM-2B</span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">参数量:</span>
                        <span class="info-value">20亿</span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">上下文长度:</span>
                        <span class="info-value">4096 tokens</span>
                    </div>
                </div>
                
                <div class="info-card">
                    <h3><i class="fas fa-bolt"></i> 性能指标</h3>
                    <div class="info-item">
                        <span class="info-label">响应时间:</span>
                        <span class="info-value" id="responseTime">-</span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">生成速度:</span>
                        <span class="info-value" id="generationSpeed">-</span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">总tokens:</span>
                        <span class="info-value" id="totalTokens">-</span>
                    </div>
                </div>
                
                <div class="info-card">
                    <h3><i class="fas fa-lightbulb"></i> 试试这些例子</h3>
                    <div class="examples">
                        <button class="example-btn" data-prompt="用Python写一个快速排序算法">
                            <i class="fas fa-code"></i> 写快速排序算法
                        </button>
                        <button class="example-btn" data-prompt="解释一下量子计算的基本原理">
                            <i class="fas fa-atom"></i> 解释量子计算
                        </button>
                        <button class="example-btn" data-prompt="帮我写一封工作汇报邮件">
                            <i class="fas fa-envelope"></i> 写工作汇报邮件
                        </button>
                        <button class="example-btn" data-prompt="什么是机器学习?用简单的例子说明">
                            <i class="fas fa-brain"></i> 解释机器学习
                        </button>
                    </div>
                </div>
                
                <div class="info-card">
                    <h3><i class="fas fa-cog"></i> 参数说明</h3>
                    <div class="info-item">
                        <span class="info-label">温度:</span>
                        <span class="info-value">控制随机性 (0.1-2.0)</span>
                    </div>
                    <div class="info-item">
                        <span class="info-label">值越小越确定</span>
                        <span class="info-label">值越大越有创意</span>
                    </div>
                </div>
            </div>
        </div>
    </div>

    <script>
        // DOM elements
        const chatHistory = document.getElementById('chatHistory');
        const messageInput = document.getElementById('messageInput');
        const sendBtn = document.getElementById('sendBtn');
        const clearBtn = document.getElementById('clearBtn');
        const typingIndicator = document.getElementById('typingIndicator');
        
        // Parameter controls
        const temperatureSlider = document.getElementById('temperature');
        const temperatureValue = document.getElementById('temperatureValue');
        const maxLengthSlider = document.getElementById('maxLength');
        const maxLengthValue = document.getElementById('maxLengthValue');
        
        // Status displays
        const responseTimeEl = document.getElementById('responseTime');
        const generationSpeedEl = document.getElementById('generationSpeed');
        const totalTokensEl = document.getElementById('totalTokens');
        const modelStatusEl = document.getElementById('modelStatus');
        
        // Keep each slider and its number input in sync
        temperatureSlider.addEventListener('input', () => {
            temperatureValue.value = temperatureSlider.value;
        });
        
        temperatureValue.addEventListener('input', () => {
            let value = parseFloat(temperatureValue.value);
            if (value < 0.1) value = 0.1;
            if (value > 2.0) value = 2.0;
            temperatureValue.value = value;
            temperatureSlider.value = value;
        });
        
        maxLengthSlider.addEventListener('input', () => {
            maxLengthValue.value = maxLengthSlider.value;
        });
        
        maxLengthValue.addEventListener('input', () => {
            let value = parseInt(maxLengthValue.value);
            if (value < 100) value = 100;
            if (value > 2048) value = 2048;
            maxLengthValue.value = value;
            maxLengthSlider.value = value;
        });
        
        // Send a message
        async function sendMessage() {
            const message = messageInput.value.trim();
            if (!message) return;
            
            // Show the user message in the chat history
            addMessage(message, 'user');
            
            // Clear the input box
            messageInput.value = '';
            
            // Disable the send button while waiting
            sendBtn.disabled = true;
            
            // Show the typing indicator
            typingIndicator.style.display = 'block';
            scrollToBottom();
            
            try {
                // Read the current parameter values
                const temperature = parseFloat(temperatureValue.value);
                const maxLength = parseInt(maxLengthValue.value);
                
                // Call the backend
                const response = await fetch('/api/chat', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        prompt: message,
                        temperature: temperature,
                        max_length: maxLength,
                        top_p: 0.9
                    })
                });
                
                const data = await response.json();
                
                // Hide the typing indicator
                typingIndicator.style.display = 'none';
                
                if (data.success) {
                    // Show the AI reply
                    addMessage(data.response, 'assistant');
                    
                    // Update the performance panel
                    updatePerformanceMetrics(data);
                } else {
                    // Show the error
                    addMessage(`Sorry, something went wrong: ${data.error || 'unknown error'}`, 'assistant');
                }
                
            } catch (error) {
                console.error('Failed to send message:', error);
                typingIndicator.style.display = 'none';
                addMessage('Sorry, a network problem occurred. Please try again later.', 'assistant');
            } finally {
                // Re-enable the send button
                sendBtn.disabled = false;
                // Return focus to the input box
                messageInput.focus();
            }
        }
        
        // Add a message to the chat history
        function addMessage(content, sender) {
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${sender}`;
            
            const avatarDiv = document.createElement('div');
            avatarDiv.className = 'avatar';
            avatarDiv.textContent = sender === 'user' ? 'You' : 'AI';
            
            const contentDiv = document.createElement('div');
            contentDiv.className = 'message-content';
            
            // For AI replies, keep line breaks; set textContent first so the
            // model output is HTML-escaped and cannot inject markup
            if (sender === 'assistant') {
                contentDiv.textContent = content;
                contentDiv.innerHTML = contentDiv.innerHTML.replace(/\n/g, '<br>');
            } else {
                contentDiv.textContent = content;
            }
            
            messageDiv.appendChild(avatarDiv);
            messageDiv.appendChild(contentDiv);
            chatHistory.appendChild(messageDiv);
            
            // Scroll to the bottom
            scrollToBottom();
        }
        
        // Update the performance panel
        function updatePerformanceMetrics(data) {
            if (data.inference_time) {
                responseTimeEl.textContent = `${data.inference_time}s`;
            }
            if (data.tokens_per_second) {
                generationSpeedEl.textContent = `${data.tokens_per_second} tokens/s`;
            }
            if (data.total_tokens) {
                totalTokensEl.textContent = data.total_tokens;
            }
        }
        
        // Scroll the history to the bottom
        function scrollToBottom() {
            chatHistory.scrollTop = chatHistory.scrollHeight;
        }
        
        // Clear the conversation
        function clearChat() {
            // Keep the initial welcome message
            const welcomeMessage = chatHistory.querySelector('.message.assistant');
            chatHistory.innerHTML = '';
            if (welcomeMessage) {
                chatHistory.appendChild(welcomeMessage);
            }
            
            // Reset the performance panel
            responseTimeEl.textContent = '-';
            generationSpeedEl.textContent = '-';
            totalTokensEl.textContent = '-';
        }
        
        // Poll the service status
        async function checkServiceStatus() {
            try {
                const response = await fetch('/api/health');
                const data = await response.json();
                
                if (data.status === 'healthy') {
                    modelStatusEl.className = 'status-indicator status-online';
                    modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> Online';
                } else {
                    modelStatusEl.className = 'status-indicator status-offline';
                    modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> Offline';
                }
            } catch (error) {
                modelStatusEl.className = 'status-indicator status-offline';
                modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> Offline';
            }
        }
        
        // Event listeners
        sendBtn.addEventListener('click', sendMessage);
        
        messageInput.addEventListener('keydown', (e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });
        
        clearBtn.addEventListener('click', clearChat);
        
        // Example-button clicks fill the input box
        document.querySelectorAll('.example-btn').forEach(btn => {
            btn.addEventListener('click', () => {
                const prompt = btn.getAttribute('data-prompt');
                messageInput.value = prompt;
                messageInput.focus();
            });
        });
        
        // Auto-grow the input box as the user types
        messageInput.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = (this.scrollHeight) + 'px';
        });
        
        // Poll the service status periodically
        checkServiceStatus();
        setInterval(checkServiceStatus, 30000); // every 30 seconds
        
        // Focus the input box on load
        messageInput.focus();
    </script>
</body>
</html>

4.2 A Quick Gradio Interface (Alternative)

If you want a UI with less effort, Gradio is a simpler route:

# gradio_app.py
import gradio as gr
import requests
from typing import Dict, Any

# Base URL of the Flask API
API_BASE = "http://localhost:8080/api"

def chat_with_model(
    message: str,
    history: list,
    temperature: float,
    max_length: int
) -> str:
    """Gradio-facing function that calls the chat API."""
    try:
        # Build the request payload
        data = {
            "prompt": message,
            "temperature": temperature,
            "max_length": max_length,
            "top_p": 0.9
        }
        
        # Send the request
        response = requests.post(
            f"{API_BASE}/chat",
            json=data,
            timeout=60
        )
        
        if response.status_code == 200:
            result = response.json()
            if result.get("success"):
                return result["response"]
            else:
                return f"Error: {result.get('error', 'unknown error')}"
        else:
            return f"Request failed: HTTP {response.status_code}"
            
    except requests.exceptions.RequestException as e:
        return f"Network error: {str(e)}"
    except Exception as e:
        return f"Processing error: {str(e)}"

def get_model_info() -> Dict[str, Any]:
    """Fetch model info from the API."""
    try:
        response = requests.get(f"{API_BASE}/info", timeout=5)
        if response.status_code == 200:
            return response.json()
    except requests.exceptions.RequestException:
        pass
    return {}

def create_gradio_interface():
    """Build the Gradio UI."""
    
    # Fetch model info from the API
    model_info = get_model_info()
    model_name = model_info.get("model", "Youtu-LLM-2B")
    capabilities = model_info.get("capabilities", ["chat", "code generation", "math reasoning"])
    
    # Build the interface
    with gr.Blocks(
        title=f"{model_name} Chat",
        theme=gr.themes.Soft()
    ) as demo:
        
        gr.Markdown(f"""
        # 🤖 {model_name} Chat Service
        
        Built on Tencent YouTu Lab's lightweight large language model. Supports: {", ".join(capabilities)}
        
        **Highlights**:
        - 🚀 Lightweight and efficient: runs smoothly on minimal VRAM
        - 💬 Smart conversation: deeply optimized Chinese dialogue
        - 🛠️ Coding assistance: supports many programming languages
        - 🧮 Math reasoning: strong at logic and calculation
        
        **Getting started**: type a question below; press Enter to send, Shift+Enter for a new line
        """)
        
        # Chat history
        chatbot = gr.Chatbot(
            label="Conversation",
            height=500,
            show_copy_button=True
        )
        
        # Status row
        with gr.Row():
            with gr.Column(scale=1):
                status_text = gr.Textbox(
                    label="Service status",
                    value="🟢 Online",
                    interactive=False
                )
            with gr.Column(scale=2):
                model_info_text = gr.Textbox(
                    label="Model info",
                    value=f"Model: {model_name} | Context length: 4096 tokens",
                    interactive=False
                )
        
        # Input row
        with gr.Row():
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your question... (e.g. write a quicksort in Python)",
                lines=3,
                scale=4
            )
            submit_btn = gr.Button("Send", variant="primary", scale=1)
        
        # Parameter controls
        with gr.Accordion("Advanced parameters", open=False):
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness: lower is more deterministic, higher is more creative"
                )
                max_length = gr.Slider(
                    minimum=100,
                    maximum=2048,
                    value=512,
                    step=100,
                    label="Max generation length",
                    info="Upper bound on the length of the generated text"
                )
        
        # Example prompts
        with gr.Accordion("Example prompts", open=True):
            examples = gr.Examples(
                examples=[
                    ["Write a quicksort algorithm in Python"],
                    ["Explain the basics of quantum computing"],
                    ["Help me draft a work status report email"],
                    ["What is machine learning? Explain with a simple example"],
                    ["Write a short story about artificial intelligence"],
                    ["How can I study more effectively? Give concrete tips"]
                ],
                inputs=msg,
                label="Click to use an example"
            )
        
        # Control buttons
        with gr.Row():
            clear_btn = gr.Button("Clear chat", variant="secondary")
            retry_btn = gr.Button("Retry last", variant="secondary")
        
        # System information
        with gr.Accordion("System info", open=False):
            gr.Markdown(f"""
            ### Specifications
            - **Model**: {model_name}
            - **Parameters**: 2 billion
            - **Context length**: 4096 tokens
            - **Languages**: Chinese, English
            - **Inference backend**: Flask + Transformers
            
            ### API endpoints
            - `POST /api/chat` - chat
            - `POST /api/batch_chat` - batched chat
            - `GET /api/health` - health check
            - `GET /api/info` - model info
            
            ### Usage tips
            1. The more detailed your question, the more accurate the answer
            2. You can ask for a specific output format (code, list, table, etc.)
            3. Temperature governs creativity; for coding tasks use a low value (0.1-0.5)
            4. For conversational tasks use a medium value (0.7-1.0)
            """)
        
        # Event handlers
        def respond(message, chat_history, temp, max_len):
            """Handle user input and append the model's reply."""
            if not message.strip():
                return "", chat_history
            
            # Append the user message to the history
            chat_history.append((message, None))
            
            # Generate a reply
            response = chat_with_model(message, chat_history, temp, max_len)
            
            # Fill in the reply for the last turn
            chat_history[-1] = (message, response)
            
            return "", chat_history
        
        # Wire up the events
        submit_btn.click(
            respond,
            [msg, chatbot, temperature, max_length],
            [msg, chatbot]
        )
        
        msg.submit(
            respond,
            [msg, chatbot, temperature, max_length],
            [msg, chatbot]
        )
        
        clear_btn.click(lambda: [], None, chatbot)
        
        def retry_last(chat_history, temp, max_len):
            """Regenerate the reply for the last user message."""
            if not chat_history:
                return chat_history
            
            # Re-run the last turn with the current parameter values
            last_message = chat_history[-1][0]
            if last_message:
                response = chat_with_model(
                    last_message,
                    chat_history[:-1],
                    temp,
                    max_len
                )
                chat_history[-1] = (last_message, response)
            
            return chat_history
        
        # Pass the sliders as inputs so retry sees their live values
        # (reading .value directly would only return the initial defaults)
        retry_btn.click(retry_last, [chatbot, temperature, max_length], chatbot)
        
        # Periodic status refresh
        def update_status():
            """Ping the health endpoint and report status."""
            try:
                response = requests.get(f"{API_BASE}/health", timeout=5)
                if response.status_code == 200:
                    return "🟢 Online"
                else:
                    return "🔴 Offline"
            except requests.exceptions.RequestException:
                return "🔴 Offline"
        
        # Refresh the status box every 30 seconds
        demo.load(update_status, None, status_text, every=30)
    
    return demo

if __name__ == "__main__":
    # 创建并启动Gradio界面
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
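To try it, start the Flask API first and then the Gradio front end (ports follow the code above):

# Terminal 1: the Flask API on port 8080
python app.py

# Terminal 2: the Gradio UI on port 7860
python gradio_app.py

Then open http://localhost:7860 in a browser.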

5. Deployment and Optimization

5.1 Production Deployment

With the full service code in hand, let's look at running it in production.

Serving the Flask app with Gunicorn:

One caution up front: each Gunicorn worker process loads its own copy of the model, so the classic cpu_count() * 2 + 1 worker count in the config below can easily exhaust memory with a 2B-parameter model. Start with one or two workers and scale up only if memory allows.

First, install Gunicorn:

pip install gunicorn

Create a Gunicorn config file:

# gunicorn_config.py
import multiprocessing

# Server settings
bind = "0.0.0.0:8080"
# NOTE: every worker loads a full copy of the model; for a 2B model,
# 1-2 workers is usually safer than the classic formula below
workers = multiprocessing.cpu_count() * 2 + 1
worker_class = "sync"
worker_connections = 1000
timeout = 120
keepalive = 2

# Logging
accesslog = "./logs/access.log"
errorlog = "./logs/error.log"
loglevel = "info"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'

# Process settings
daemon = False
pidfile = "./gunicorn.pid"
umask = 0
user = None
group = None
tmp_upload_dir = None

# Server hooks
def post_fork(server, worker):
    server.log.info("Worker spawned (pid: %s)", worker.pid)

def pre_fork(server, worker):
    pass

def pre_exec(server):
    server.log.info("Forked child, re-executing.")

def when_ready(server):
    server.log.info("Server is ready. Spawning workers")

def worker_int(worker):
    worker.log.info("worker received INT or QUIT signal")

def worker_abort(worker):
    worker.log.info("worker received SIGABRT signal")

Start the service:

# Start with the config file
gunicorn -c gunicorn_config.py app:app

# Or with a one-liner (mind the per-worker memory note above)
gunicorn -w 4 -b 0.0.0.0:8080 app:app
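To keep the service alive across reboots, a systemd unit is the usual next step. A minimal sketch, where the install path /opt/youtu-2b-service and the www-data user are placeholders to adapt to your setup:

# /etc/systemd/system/youtu-2b.service  (paths and user are placeholders)
[Unit]
Description=Youtu-2B Flask API
After=network.target

[Service]
User=www-data
WorkingDirectory=/opt/youtu-2b-service
ExecStart=/opt/youtu-2b-service/venv/bin/gunicorn -c gunicorn_config.py app:app
Restart=on-failure

[Install]
WantedBy=multi-user.target

Enable and start it with sudo systemctl enable --now youtu-2b.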

Containerized deployment with Docker:

Create a Dockerfile:

# Dockerfile
FROM python:3.9-slim

# Working directory
WORKDIR /app

# System build dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency list
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

# Copy the application code
COPY . .

# Create the log directory
RUN mkdir -p logs

# Download the model (either at build time or at runtime).
# Here we assume the model is already in the models directory;
# if not, uncomment the line below
# RUN python download_model.py

# Expose the service port
EXPOSE 8080

# Start command
CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:8080", "app:app"]

Create docker-compose.yml:

# docker-compose.yml
version: '3.8'

services:
  youtu-2b-service:
    build: .
    container_name: youtu-2b-service
    ports:
      - "8080:8080"
    volumes:
      - ./models:/app/models
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - TZ=Asia/Shanghai
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8080/api/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

Build and run:

# Build the image
docker-compose build

# Start the service
docker-compose up -d

# Tail the logs
docker-compose logs -f

# Stop the service
docker-compose down

5.2 Performance Tuning

Optimized model loading:

# core/model_loader_optimized.py
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from contextlib import contextmanager
import logging

logger = logging.getLogger(__name__)

class OptimizedModelLoader:
    """Model loader with extra performance tweaks."""
    
    def __init__(self, model_path: str = "./models/youtu-2b"):
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.device = None
        
    def load_model_with_optimizations(self, use_gpu: bool = True):
        """Load the model with device-specific optimizations."""
        try:
            logger.info("Loading model with optimizations...")
            
            # Pick the device
            if use_gpu and torch.cuda.is_available():
                self.device = torch.device("cuda")
                torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune kernels
                torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls
            else:
                self.device = torch.device("cpu")
                # Pin the thread count to the available cores
                torch.set_num_threads(max(1, os.cpu_count() or 1))
            
            # Load the tokenizer (with a local cache)
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                cache_dir="./cache"  # local cache directory
            )
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Model loading options
            load_config = {
                "trust_remote_code": True,
                "cache_dir": "./cache",
                "low_cpu_mem_usage": True,  # reduce peak CPU memory while loading
            }
            
            # Choose the dtype and placement per device
            if self.device.type == "cuda":
                load_config.update({
                    "torch_dtype": torch.float16,  # half precision on GPU
                    "device_map": "auto",  # spread across the available GPUs
                })
            else:
                load_config.update({
                    "torch_dtype": torch.float32,
                })
            
            # Load the model
            logger.info("Loading model weights (optimized)...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                **load_config
            )
            
            # Extra optimizations on CPU
            if self.device.type == "cpu":
                self.model = self.model.to(self.device)
                # Try Intel's PyTorch extension if it is installed
                try:
                    import intel_extension_for_pytorch as ipex
                    self.model = ipex.optimize(self.model)
                    logger.info("Intel Extension for PyTorch enabled")
                except ImportError:
                    logger.info("Intel Extension for PyTorch not installed; using plain CPU mode")
            
            # Switch to evaluation mode
            self.model.eval()
            
            # Some custom model classes expose an extra inference hook
            if hasattr(self.model, "prepare_for_inference"):
                self.model.prepare_for_inference()
            
            logger.info("Optimized model load complete!")
            return True, "Model loaded successfully"
            
        except Exception as e:
            error_msg = f"Model load failed: {str(e)}"
            logger.error(error_msg)
            return False, error_msg
    
    @contextmanager
    def inference_context(self):
        """Context manager that keeps GPU memory tidy around a request."""
        try:
            if self.device.type == "cuda":
                # Clear the CUDA cache before the request
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            
            yield self.model, self.tokenizer
            
        finally:
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
    
    def optimize_for_inference(self):
        """Apply further inference-time tweaks."""
        if self.model is None:
            return
        
        logger.info("Applying inference optimizations...")
        
        # Enable the KV cache if the model supports it
        if hasattr(self.model.config, "use_cache"):
            self.model.config.use_cache = True
        
        # Note: gradient checkpointing trades compute for memory during
        # *training* only; it does nothing useful under torch.no_grad()
        # inference (and conflicts with use_cache), so it stays off here
        
        # Extra tweaks on CPU
        if self.device.type == "cpu":
            # Cap the thread count (oversubscription can hurt throughput)
            torch.set_num_threads(min(4, torch.get_num_threads()))
            
            # Make parameters contiguous in memory for faster CPU access
            # (the original snippet was truncated here; this is one
            # reasonable completion)
            try:
                for p in self.model.parameters():
                    p.data = p.data.contiguous()
            except Exception:
                pass
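Wrapping up, a sketch of how these pieces fit together (illustrative only; it assumes the class above is importable from core):

# Example: optimized load plus a tidy inference context
from core.model_loader_optimized import OptimizedModelLoader

loader = OptimizedModelLoader()
ok, msg = loader.load_model_with_optimizations(use_gpu=True)
assert ok, msg
loader.optimize_for_inference()

with loader.inference_context() as (model, tokenizer):
    inputs = tokenizer("Hello!", return_tensors="pt").to(loader.device)
    outputs = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))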