Youtu-2B模型服务封装:Flask+WebUI完整指南
1. 引言:为什么选择Youtu-2B?
如果你正在寻找一个既轻量又聪明的AI对话助手,那么Youtu-2B模型绝对值得你花时间了解一下。想象一下,一个只有20亿参数的“小个子”模型,却能在数学题、编程任务和日常聊天中表现出惊人的能力,而且对硬件要求极低——普通消费级显卡甚至CPU都能流畅运行。
这就是腾讯优图实验室推出的Youtu-LLM-2B模型带来的惊喜。但模型本身只是开始,如何把它变成一个随时可用的服务,才是真正发挥价值的关键。本文将带你一步步完成从模型到服务的完整封装,使用Flask构建稳健的后端API,并搭配一个简洁美观的WebUI界面,让你和你的团队能够轻松地与这个智能助手对话。
无论你是想快速搭建一个内部知识问答系统,还是需要一个编程辅助工具,或是单纯想体验轻量级大模型的魅力,这篇指南都将为你提供清晰的路径。我们不会涉及复杂的理论推导,而是聚焦于“怎么做”——从环境准备到代码实现,从界面设计到部署优化,每个环节都有详细的步骤和可运行的代码示例。
2. 环境准备与快速部署
2.1 系统要求与依赖安装
在开始之前,我们先看看需要准备什么。好消息是,Youtu-2B对硬件的要求相当友好。
基础环境要求:
- 操作系统:Ubuntu 20.04/22.04或CentOS 7/8(Windows用户建议使用WSL2)
- Python版本:Python 3.8或3.9
- 内存:至少8GB RAM
- 存储空间:10GB可用空间(用于模型文件和依赖包)
- 显卡:可选但推荐(有GPU会更快)
- NVIDIA GPU(4GB显存以上)
- 或仅使用CPU(速度稍慢但完全可用)
一键安装所有依赖:
打开终端,执行以下命令来搭建完整的环境:
# 创建项目目录
mkdir youtu-2b-service && cd youtu-2b-service
# 创建Python虚拟环境(推荐)
python -m venv venv
source venv/bin/activate # Linux/Mac
# 或 venv\Scripts\activate # Windows
# 安装PyTorch(根据你的CUDA版本选择)
# 如果没有GPU(安装CPU版)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# 如果有GPU(以CUDA 11.8为例)
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 安装模型推理和Web框架(版本约束要加引号,否则 > 会被shell当成重定向)
pip install "transformers>=4.30.0"
pip install "flask>=2.3.0"
pip install "flask-cors>=4.0.0"
# 安装WebUI相关依赖
pip install "gradio>=3.40.0"
# 安装其他工具库
pip install "requests>=2.31.0"
pip install "psutil>=5.9.0"  # 健康检查接口需要
pip install "sentencepiece>=0.1.99"  # 分词器需要
安装过程可能需要几分钟,取决于你的网络速度。如果遇到下载慢的问题,可以考虑使用国内的镜像源:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple [包名]
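依赖装好后,可以先用一小段脚本确认PyTorch和GPU是否就绪:
# check_env.py
import torch

print("PyTorch版本:", torch.__version__)
print("CUDA可用:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))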
2.2 模型下载与验证
环境准备好后,我们需要下载Youtu-2B模型。这里有两种方式:
方式一:直接从Hugging Face下载(推荐)
# download_model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# 指定模型路径
model_name = "Tencent-YouTu-Research/Youtu-LLM-2B"
save_path = "./models/youtu-2b"
# 创建保存目录
os.makedirs(save_path, exist_ok=True)
print("开始下载Youtu-2B模型...")
print("这可能需要一些时间,模型大小约4GB")
# 下载tokenizer和模型
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
torch_dtype="auto" # 自动选择数据类型
)
# 保存到本地
tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)
print(f"模型已保存到: {save_path}")
print("下载完成!")
方式二:使用预下载的模型文件
如果你已经下载了模型文件,可以直接放到指定目录:
# 假设模型文件在本地路径
mkdir -p ./models/youtu-2b
cp -r /path/to/your/model/files/* ./models/youtu-2b/
验证模型是否正常:
下载完成后,运行一个简单的测试脚本来验证模型能否正常工作:
# test_model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_path = "./models/youtu-2b"
print("加载模型和分词器...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" # 自动分配到可用设备
)
# 测试推理
test_prompt = "你好,请介绍一下你自己。"
inputs = tokenizer(test_prompt, return_tensors="pt")
# 将输入移到模型所在的设备
if torch.cuda.is_available():
inputs = {k: v.cuda() for k, v in inputs.items()}
print("开始推理...")
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100,
temperature=0.7,
do_sample=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"测试回复: {response}")
print("模型测试通过!")
如果看到模型正常回复,说明环境配置成功。接下来我们开始构建服务。
3. Flask后端服务封装
3.1 项目结构设计
在开始编码之前,我们先规划一下项目的目录结构。一个好的结构能让代码更清晰,也便于后期维护。
youtu-2b-service/
├── app.py # Flask主应用
├── config.py # 配置文件
├── requirements.txt # 依赖列表
├── models/
│ └── youtu-2b/ # 模型文件
├── api/
│ ├── __init__.py
│ ├── chat.py # 聊天API
│ └── health.py # 健康检查API
├── core/
│ ├── __init__.py
│ ├── model_loader.py # 模型加载器
│ └── inference.py # 推理引擎
├── static/ # 静态文件
│ ├── css/
│ └── js/
├── templates/ # 模板文件
│ └── index.html
└── utils/
├── __init__.py
└── logger.py # 日志工具
创建这个结构:
mkdir -p api core static/css static/js templates utils
touch app.py config.py requirements.txt
touch api/__init__.py api/chat.py api/health.py
touch core/__init__.py core/model_loader.py core/inference.py
touch utils/__init__.py utils/logger.py
touch templates/index.html
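顺手把 requirements.txt 也填上。下面的版本号与前文安装命令一致,仅作示例,gunicorn 供第5节部署时使用,请按实际环境调整:
# requirements.txt(示例版本,按需调整)
torch
transformers>=4.30.0
flask>=2.3.0
flask-cors>=4.0.0
gradio>=3.40.0
requests>=2.31.0
psutil>=5.9.0
sentencepiece>=0.1.99
gunicorn>=21.2.0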
3.2 核心模型加载器
我们先实现模型加载模块,这是整个服务的基础:
# core/model_loader.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import Tuple
logger = logging.getLogger(__name__)
class ModelLoader:
"""模型加载和管理类"""
def __init__(self, model_path: str = "./models/youtu-2b"):
self.model_path = model_path
self.tokenizer = None
self.model = None
self.device = None
def load_model(self, use_gpu: bool = True) -> Tuple[bool, str]:
"""
加载模型和分词器
Args:
use_gpu: 是否使用GPU
Returns:
(成功标志, 消息)
"""
try:
logger.info(f"开始加载模型,路径: {self.model_path}")
# 确定设备
if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
logger.info(f"使用GPU: {torch.cuda.get_device_name(0)}")
else:
self.device = torch.device("cpu")
logger.info("使用CPU")
# 加载分词器
logger.info("加载分词器...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True
)
# 设置pad_token(如果不存在)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# 加载模型
logger.info("加载模型...")
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
device_map="auto" if self.device.type == "cuda" else None
)
# 如果使用CPU,需要手动移动模型
if self.device.type == "cpu":
self.model = self.model.to(self.device)
# 设置为评估模式
self.model.eval()
logger.info("模型加载完成!")
return True, "模型加载成功"
except Exception as e:
error_msg = f"模型加载失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
def get_model_info(self) -> dict:
"""获取模型信息"""
if self.model is None:
return {"status": "not_loaded"}
info = {
"status": "loaded",
"model_path": self.model_path,
"device": str(self.device),
"model_type": type(self.model).__name__,
"parameters": sum(p.numel() for p in self.model.parameters()),
}
if self.device.type == "cuda":
info["gpu_memory"] = f"{torch.cuda.memory_allocated() / 1024**3:.2f} GB"
return info
def unload_model(self):
"""卸载模型释放内存"""
if self.model is not None:
del self.model
self.model = None
if torch.cuda.is_available(): torch.cuda.empty_cache()
logger.info("模型已卸载")
3.3 推理引擎实现
接下来实现推理引擎,这是处理用户请求的核心:
# core/inference.py
import torch
import time
from typing import Dict, List
import logging
from .model_loader import ModelLoader
logger = logging.getLogger(__name__)
class InferenceEngine:
"""推理引擎"""
def __init__(self, model_loader: ModelLoader):
self.model_loader = model_loader
self.tokenizer = model_loader.tokenizer
self.model = model_loader.model
self.device = model_loader.device
def generate_response(
self,
prompt: str,
max_length: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
repetition_penalty: float = 1.1,
**kwargs
) -> Dict:
"""
生成回复
Args:
prompt: 输入文本
max_length: 最大生成长度
temperature: 温度参数(控制随机性)
top_p: 核采样参数
repetition_penalty: 重复惩罚
Returns:
包含回复和元数据的字典
"""
if self.model is None or self.tokenizer is None:
return {
"response": "模型未加载,请稍后重试",
"error": "model_not_loaded"
}
try:
start_time = time.time()
# 编码输入
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length
)
# 移动到设备
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# 生成配置
generate_kwargs = {
"max_new_tokens": min(1024, max_length),
"temperature": max(0.1, min(2.0, temperature)),
"top_p": max(0.1, min(1.0, top_p)),
"repetition_penalty": max(1.0, min(2.0, repetition_penalty)),
"do_sample": temperature > 0.1,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# 更新额外参数
generate_kwargs.update(kwargs)
# 生成回复
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generate_kwargs
)
# 解码输出
response = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
# 计算耗时
inference_time = time.time() - start_time
# 统计信息
input_tokens = inputs["input_ids"].shape[1]
output_tokens = outputs.shape[1] - input_tokens
tokens_per_second = output_tokens / inference_time if inference_time > 0 else 0
logger.info(f"推理完成: {input_tokens} -> {output_tokens} tokens, "
f"耗时: {inference_time:.2f}s, "
f"速度: {tokens_per_second:.1f} tokens/s")
return {
"response": response.strip(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
"inference_time": round(inference_time, 3),
"tokens_per_second": round(tokens_per_second, 1),
"model": "Youtu-LLM-2B",
"success": True
}
except Exception as e:
error_msg = f"推理失败: {str(e)}"
logger.error(error_msg)
return {
"response": "生成回复时出现错误",
"error": error_msg,
"success": False
}
def batch_generate(
self,
prompts: List[str],
**kwargs
) -> List[Dict]:
"""批量生成回复"""
results = []
for prompt in prompts:
result = self.generate_response(prompt, **kwargs)
results.append(result)
return results
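推理引擎同样可以独立验证。下面的草图把它和上一节的加载器串起来:
# 使用示例(最小草图,衔接3.2节的 ModelLoader)
from core.model_loader import ModelLoader
from core.inference import InferenceEngine

loader = ModelLoader("./models/youtu-2b")
ok, msg = loader.load_model(use_gpu=True)
assert ok, msg

engine = InferenceEngine(loader)
result = engine.generate_response("用一句话介绍Python", max_length=128)
if result.get("success"):
    print(result["response"])
    print(f"速度: {result['tokens_per_second']} tokens/s")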
3.4 Flask API实现
现在我们来创建Flask API端点:
# api/chat.py
from flask import Blueprint, request, jsonify
import logging
from core.inference import InferenceEngine
# 创建蓝图
chat_bp = Blueprint('chat', __name__)
logger = logging.getLogger(__name__)
# 全局推理引擎实例
_inference_engine = None
def init_inference_engine(model_loader):
"""初始化推理引擎"""
global _inference_engine
if _inference_engine is None:
_inference_engine = InferenceEngine(model_loader)
return _inference_engine
@chat_bp.route('/chat', methods=['POST'])
def chat():
"""
聊天API端点
请求格式:
{
"prompt": "你的问题",
"max_length": 512,
"temperature": 0.7,
"top_p": 0.9,
"stream": false
}
"""
try:
# 解析请求数据
if request.is_json:
data = request.get_json()
else:
data = request.form.to_dict()
# 获取必需参数
prompt = data.get('prompt', '').strip()
if not prompt:
return jsonify({
"error": "prompt参数不能为空",
"success": False
}), 400
# 获取可选参数
max_length = int(data.get('max_length', 512))
temperature = float(data.get('temperature', 0.7))
top_p = float(data.get('top_p', 0.9))
stream = data.get('stream', False)  # 预留的流式开关,当前版本未实现
# 参数验证
if max_length < 1 or max_length > 4096:
return jsonify({
"error": "max_length必须在1-4096之间",
"success": False
}), 400
if temperature < 0.1 or temperature > 2.0:
return jsonify({
"error": "temperature必须在0.1-2.0之间",
"success": False
}), 400
# 生成回复
if _inference_engine is None:
return jsonify({
"error": "推理引擎未初始化",
"success": False
}), 503
# 流式输出(stream=True)暂未实现,这里统一返回完整结果
result = _inference_engine.generate_response(
prompt=prompt,
max_length=max_length,
temperature=temperature,
top_p=top_p
)
return jsonify(result)
except ValueError as e:
logger.error(f"参数错误: {str(e)}")
return jsonify({
"error": f"参数格式错误: {str(e)}",
"success": False
}), 400
except Exception as e:
logger.error(f"处理请求时出错: {str(e)}")
return jsonify({
"error": f"服务器内部错误: {str(e)}",
"success": False
}), 500
@chat_bp.route('/batch_chat', methods=['POST'])
def batch_chat():
"""批量聊天API"""
try:
data = request.get_json()
prompts = data.get('prompts', [])
if not isinstance(prompts, list) or len(prompts) == 0:
return jsonify({
"error": "prompts必须是非空数组",
"success": False
}), 400
if len(prompts) > 10: # 限制批量大小
return jsonify({
"error": "批量请求最多支持10个prompt",
"success": False
}), 400
# 获取参数
max_length = int(data.get('max_length', 512))
temperature = float(data.get('temperature', 0.7))
if _inference_engine is None:
return jsonify({
"error": "推理引擎未初始化",
"success": False
}), 503
results = _inference_engine.batch_generate(
prompts=prompts,
max_length=max_length,
temperature=temperature
)
return jsonify({
"results": results,
"count": len(results),
"success": True
})
except Exception as e:
logger.error(f"批量处理失败: {str(e)}")
return jsonify({
"error": f"处理失败: {str(e)}",
"success": False
}), 500
# api/health.py
from flask import Blueprint, jsonify
import time
import psutil
import torch
health_bp = Blueprint('health', __name__)
@health_bp.route('/health', methods=['GET'])
def health_check():
"""健康检查端点"""
try:
# 基础系统信息
cpu_percent = psutil.cpu_percent(interval=0.1)
memory = psutil.virtual_memory()
health_data = {
"status": "healthy",
"service": "youtu-2b-api",
"timestamp": psutil.boot_time(),
"system": {
"cpu_percent": cpu_percent,
"memory_percent": memory.percent,
"memory_available_gb": round(memory.available / (1024**3), 2),
"memory_total_gb": round(memory.total / (1024**3), 2)
}
}
# GPU信息(如果可用)
if torch.cuda.is_available():
gpu_memory = torch.cuda.memory_allocated()
gpu_memory_reserved = torch.cuda.memory_reserved()
health_data["gpu"] = {
"available": True,
"device_count": torch.cuda.device_count(),
"current_device": torch.cuda.current_device(),
"device_name": torch.cuda.get_device_name(0),
"memory_allocated_gb": round(gpu_memory / (1024**3), 2),
"memory_reserved_gb": round(gpu_memory_reserved / (1024**3), 2)
}
else:
health_data["gpu"] = {"available": False}
return jsonify(health_data)
except Exception as e:
return jsonify({
"status": "unhealthy",
"error": str(e)
}), 500
@health_bp.route('/info', methods=['GET'])
def model_info():
"""模型信息端点"""
# 这里需要从全局获取模型信息
# 为了简化,我们先返回基础信息
info = {
"model": "Youtu-LLM-2B",
"provider": "Tencent YouTu Research",
"parameters": "2B",
"context_length": 4096,
"languages": ["中文", "英文"],
"capabilities": ["对话", "代码生成", "数学推理", "文案创作"]
}
return jsonify(info)
3.5 Flask主应用
最后,我们把所有部分整合到Flask主应用中:
# app.py
from flask import Flask, render_template, jsonify
from flask_cors import CORS
import logging
from datetime import datetime
import os
from core.model_loader import ModelLoader
from api.chat import chat_bp, init_inference_engine
from api.health import health_bp
# 创建日志目录(必须在配置FileHandler之前,否则首次启动会报错)
os.makedirs("logs", exist_ok=True)
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f"logs/app_{datetime.now().strftime('%Y%m%d')}.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# 创建Flask应用
app = Flask(__name__)
CORS(app) # 允许跨域请求
# 配置
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB最大请求
app.json.ensure_ascii = False  # 支持中文(Flask 2.3+ 的写法)
# 创建静态资源目录
os.makedirs("static/css", exist_ok=True)
os.makedirs("static/js", exist_ok=True)
# 全局模型加载器
model_loader = None
inference_engine = None
def initialize_services():
"""初始化所有服务"""
global model_loader, inference_engine
logger.info("开始初始化服务...")
# 初始化模型加载器
model_loader = ModelLoader(model_path="./models/youtu-2b")
success, message = model_loader.load_model(use_gpu=True)
if not success:
logger.error(f"模型加载失败: {message}")
# 可以在这里实现重试逻辑或降级方案
raise RuntimeError(f"模型加载失败: {message}")
# 初始化推理引擎
inference_engine = init_inference_engine(model_loader)
logger.info("服务初始化完成!")
return True
# 注意:Flask 2.3 起已移除 before_first_request 钩子,
# 这里改为在模块导入时初始化(Gunicorn 的每个 worker 启动时都会执行)。
# 失败时只记录日志,/api/status 会暴露"未就绪"状态
try:
initialize_services()
except Exception as e:
logger.error(f"初始化失败: {e}")
# 注册蓝图
app.register_blueprint(chat_bp, url_prefix='/api')
app.register_blueprint(health_bp, url_prefix='/api')
@app.route('/')
def index():
"""主页"""
return render_template('index.html')
@app.route('/api/status')
def status():
"""服务状态"""
if model_loader is None or inference_engine is None:
return jsonify({
"status": "initializing",
"message": "服务正在初始化"
})
model_info = model_loader.get_model_info()
return jsonify({
"status": "running",
"model": model_info,
"service": "Youtu-2B API Service",
"version": "1.0.0",
"endpoints": {
"chat": "/api/chat",
"batch_chat": "/api/batch_chat",
"health": "/api/health",
"info": "/api/info"
}
})
@app.errorhandler(404)
def not_found(error):
return jsonify({"error": "资源不存在", "success": False}), 404
@app.errorhandler(500)
def internal_error(error):
logger.error(f"服务器错误: {error}")
return jsonify({"error": "服务器内部错误", "success": False}), 500
if __name__ == '__main__':
# 服务已在模块导入时初始化,这里直接启动开发服务器
# (即使模型加载失败也会启动,某些功能不可用,可通过 /api/status 查看)
# 启动Flask应用
app.run(
host='0.0.0.0',
port=8080,
debug=False, # 生产环境设为False
threaded=True
)
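服务启动后,可以先用一个小脚本对几个接口做冒烟测试(示例脚本,假设服务监听本机8080端口):
# smoke_test.py(示例)
import requests

BASE = "http://localhost:8080/api"

# 健康检查
health = requests.get(f"{BASE}/health", timeout=5).json()
print("健康状态:", health.get("status"))

# 单条对话
resp = requests.post(
    f"{BASE}/chat",
    json={"prompt": "你好,介绍一下你自己", "max_length": 256, "temperature": 0.7},
    timeout=120,
).json()
print("回复:", resp.get("response"))

# 批量对话
batch = requests.post(
    f"{BASE}/batch_chat",
    json={"prompts": ["1+1等于几?", "用一句话解释递归"]},
    timeout=300,
).json()
print("批量结果数:", batch.get("count"))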
4. WebUI界面设计与实现
4.1 基础HTML界面
现在我们来创建一个美观实用的Web界面:
<!-- templates/index.html -->
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Youtu-2B 智能对话服务</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
border-radius: 20px;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
text-align: center;
}
.header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
}
.header p {
font-size: 1.1rem;
opacity: 0.9;
max-width: 600px;
margin: 0 auto;
line-height: 1.6;
}
.main-content {
display: flex;
min-height: 600px;
}
.chat-container {
flex: 3;
display: flex;
flex-direction: column;
border-right: 1px solid #e5e7eb;
}
.chat-history {
flex: 1;
overflow-y: auto;
padding: 20px;
background: #f9fafb;
}
.message {
margin-bottom: 20px;
display: flex;
gap: 12px;
}
.message.user {
flex-direction: row-reverse;
}
.avatar {
width: 40px;
height: 40px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-weight: bold;
flex-shrink: 0;
}
.user .avatar {
background: #667eea;
color: white;
}
.assistant .avatar {
background: #10b981;
color: white;
}
.message-content {
max-width: 70%;
padding: 12px 16px;
border-radius: 18px;
line-height: 1.5;
}
.user .message-content {
background: #667eea;
color: white;
border-bottom-right-radius: 4px;
}
.assistant .message-content {
background: #f3f4f6;
color: #1f2937;
border-bottom-left-radius: 4px;
}
.input-area {
padding: 20px;
border-top: 1px solid #e5e7eb;
background: white;
}
.input-wrapper {
display: flex;
gap: 10px;
}
textarea {
flex: 1;
padding: 12px 16px;
border: 2px solid #e5e7eb;
border-radius: 12px;
font-size: 16px;
resize: none;
font-family: inherit;
transition: border-color 0.3s;
}
textarea:focus {
outline: none;
border-color: #667eea;
}
.send-btn {
padding: 0 24px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 12px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, opacity 0.2s;
}
.send-btn:hover {
transform: translateY(-2px);
opacity: 0.9;
}
.send-btn:disabled {
opacity: 0.5;
cursor: not-allowed;
transform: none;
}
.controls {
display: flex;
gap: 10px;
margin-top: 12px;
flex-wrap: wrap;
}
.control-group {
display: flex;
align-items: center;
gap: 8px;
}
.control-group label {
font-size: 14px;
color: #6b7280;
}
.control-group input[type="range"] {
width: 120px;
}
.control-group input[type="number"] {
width: 60px;
padding: 4px 8px;
border: 1px solid #d1d5db;
border-radius: 6px;
text-align: center;
}
.sidebar {
flex: 1;
padding: 20px;
background: #f9fafb;
}
.info-card {
background: white;
border-radius: 12px;
padding: 20px;
margin-bottom: 20px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.info-card h3 {
color: #1f2937;
margin-bottom: 12px;
font-size: 1.1rem;
}
.info-item {
display: flex;
justify-content: space-between;
padding: 8px 0;
border-bottom: 1px solid #f3f4f6;
}
.info-item:last-child {
border-bottom: none;
}
.info-label {
color: #6b7280;
font-size: 14px;
}
.info-value {
color: #1f2937;
font-weight: 500;
}
.status-indicator {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 4px 12px;
border-radius: 20px;
font-size: 14px;
font-weight: 500;
}
.status-online {
background: #d1fae5;
color: #065f46;
}
.status-offline {
background: #fee2e2;
color: #991b1b;
}
.typing-indicator {
display: none;
padding: 12px;
text-align: center;
color: #6b7280;
}
.typing-dots {
display: inline-flex;
gap: 4px;
}
.typing-dots span {
width: 8px;
height: 8px;
background: #9ca3af;
border-radius: 50%;
animation: typing 1.4s infinite;
}
.typing-dots span:nth-child(2) {
animation-delay: 0.2s;
}
.typing-dots span:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes typing {
0%, 60%, 100% {
transform: translateY(0);
opacity: 0.6;
}
30% {
transform: translateY(-6px);
opacity: 1;
}
}
.examples {
display: grid;
gap: 10px;
margin-top: 15px;
}
.example-btn {
padding: 10px 16px;
background: white;
border: 1px solid #e5e7eb;
border-radius: 8px;
text-align: left;
font-size: 14px;
color: #4b5563;
cursor: pointer;
transition: all 0.2s;
}
.example-btn:hover {
background: #f3f4f6;
border-color: #d1d5db;
transform: translateX(4px);
}
@media (max-width: 768px) {
.main-content {
flex-direction: column;
}
.chat-container {
border-right: none;
border-bottom: 1px solid #e5e7eb;
}
.message-content {
max-width: 85%;
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1><i class="fas fa-robot"></i> Youtu-2B 智能对话服务</h1>
<p>基于腾讯优图实验室的轻量化大语言模型,提供高效、智能的对话体验</p>
</div>
<div class="main-content">
<div class="chat-container">
<div class="chat-history" id="chatHistory">
<!-- 聊天记录将在这里显示 -->
<div class="message assistant">
<div class="avatar">AI</div>
<div class="message-content">
你好!我是Youtu-2B智能助手,基于腾讯优图实验室的轻量化大语言模型。我可以帮助你解答问题、编写代码、进行逻辑推理等。有什么可以帮你的吗?
</div>
</div>
</div>
<div class="typing-indicator" id="typingIndicator">
<div class="typing-dots">
<span></span>
<span></span>
<span></span>
</div>
正在思考...
</div>
<div class="input-area">
<div class="input-wrapper">
<textarea
id="messageInput"
placeholder="输入你的问题...(例如:帮我写一个Python快速排序算法)"
rows="3"
></textarea>
<button class="send-btn" id="sendBtn">
<i class="fas fa-paper-plane"></i> 发送
</button>
</div>
<div class="controls">
<div class="control-group">
<label for="temperature">温度:</label>
<input type="range" id="temperature" min="0.1" max="2.0" step="0.1" value="0.7">
<input type="number" id="temperatureValue" min="0.1" max="2.0" step="0.1" value="0.7">
</div>
<div class="control-group">
<label for="maxLength">最大长度:</label>
<input type="range" id="maxLength" min="100" max="2048" step="100" value="512">
<input type="number" id="maxLengthValue" min="100" max="2048" step="100" value="512">
</div>
<button class="send-btn" id="clearBtn" style="background: #6b7280; padding: 8px 16px;">
<i class="fas fa-trash"></i> 清空对话
</button>
</div>
</div>
</div>
<div class="sidebar">
<div class="info-card">
<h3><i class="fas fa-info-circle"></i> 服务状态</h3>
<div class="info-item">
<span class="info-label">模型状态:</span>
<span class="info-value">
<span class="status-indicator status-online" id="modelStatus">
<i class="fas fa-circle"></i> 在线
</span>
</span>
</div>
<div class="info-item">
<span class="info-label">模型名称:</span>
<span class="info-value">Youtu-LLM-2B</span>
</div>
<div class="info-item">
<span class="info-label">参数量:</span>
<span class="info-value">20亿</span>
</div>
<div class="info-item">
<span class="info-label">上下文长度:</span>
<span class="info-value">4096 tokens</span>
</div>
</div>
<div class="info-card">
<h3><i class="fas fa-bolt"></i> 性能指标</h3>
<div class="info-item">
<span class="info-label">响应时间:</span>
<span class="info-value" id="responseTime">-</span>
</div>
<div class="info-item">
<span class="info-label">生成速度:</span>
<span class="info-value" id="generationSpeed">-</span>
</div>
<div class="info-item">
<span class="info-label">总tokens:</span>
<span class="info-value" id="totalTokens">-</span>
</div>
</div>
<div class="info-card">
<h3><i class="fas fa-lightbulb"></i> 试试这些例子</h3>
<div class="examples">
<button class="example-btn" data-prompt="用Python写一个快速排序算法">
<i class="fas fa-code"></i> 写快速排序算法
</button>
<button class="example-btn" data-prompt="解释一下量子计算的基本原理">
<i class="fas fa-atom"></i> 解释量子计算
</button>
<button class="example-btn" data-prompt="帮我写一封工作汇报邮件">
<i class="fas fa-envelope"></i> 写工作汇报邮件
</button>
<button class="example-btn" data-prompt="什么是机器学习?用简单的例子说明">
<i class="fas fa-brain"></i> 解释机器学习
</button>
</div>
</div>
<div class="info-card">
<h3><i class="fas fa-cog"></i> 参数说明</h3>
<div class="info-item">
<span class="info-label">温度:</span>
<span class="info-value">控制随机性 (0.1-2.0)</span>
</div>
<div class="info-item">
<span class="info-label">值越小越确定</span>
<span class="info-label">值越大越有创意</span>
</div>
</div>
</div>
</div>
</div>
<script>
// DOM元素
const chatHistory = document.getElementById('chatHistory');
const messageInput = document.getElementById('messageInput');
const sendBtn = document.getElementById('sendBtn');
const clearBtn = document.getElementById('clearBtn');
const typingIndicator = document.getElementById('typingIndicator');
// 参数控制
const temperatureSlider = document.getElementById('temperature');
const temperatureValue = document.getElementById('temperatureValue');
const maxLengthSlider = document.getElementById('maxLength');
const maxLengthValue = document.getElementById('maxLengthValue');
// 状态显示
const responseTimeEl = document.getElementById('responseTime');
const generationSpeedEl = document.getElementById('generationSpeed');
const totalTokensEl = document.getElementById('totalTokens');
const modelStatusEl = document.getElementById('modelStatus');
// 同步滑块和输入框的值
temperatureSlider.addEventListener('input', () => {
temperatureValue.value = temperatureSlider.value;
});
temperatureValue.addEventListener('input', () => {
let value = parseFloat(temperatureValue.value);
if (value < 0.1) value = 0.1;
if (value > 2.0) value = 2.0;
temperatureValue.value = value;
temperatureSlider.value = value;
});
maxLengthSlider.addEventListener('input', () => {
maxLengthValue.value = maxLengthSlider.value;
});
maxLengthValue.addEventListener('input', () => {
let value = parseInt(maxLengthValue.value);
if (value < 100) value = 100;
if (value > 2048) value = 2048;
maxLengthValue.value = value;
maxLengthSlider.value = value;
});
// 发送消息
async function sendMessage() {
const message = messageInput.value.trim();
if (!message) return;
// 添加用户消息到聊天记录
addMessage(message, 'user');
// 清空输入框
messageInput.value = '';
// 禁用发送按钮
sendBtn.disabled = true;
// 显示正在输入指示器
typingIndicator.style.display = 'block';
scrollToBottom();
try {
// 获取当前参数
const temperature = parseFloat(temperatureValue.value);
const maxLength = parseInt(maxLengthValue.value);
// 发送请求到后端
const response = await fetch('/api/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
prompt: message,
temperature: temperature,
max_length: maxLength,
top_p: 0.9
})
});
const data = await response.json();
// 隐藏正在输入指示器
typingIndicator.style.display = 'none';
if (data.success) {
// 添加AI回复到聊天记录
addMessage(data.response, 'assistant');
// 更新性能指标
updatePerformanceMetrics(data);
} else {
// 显示错误信息
addMessage(`抱歉,出错了:${data.error || '未知错误'}`, 'assistant');
}
} catch (error) {
console.error('发送消息失败:', error);
typingIndicator.style.display = 'none';
addMessage('抱歉,网络连接出现问题,请稍后重试。', 'assistant');
} finally {
// 重新启用发送按钮
sendBtn.disabled = false;
// 聚焦到输入框
messageInput.focus();
}
}
// 添加消息到聊天记录
function addMessage(content, sender) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${sender}`;
const avatarDiv = document.createElement('div');
avatarDiv.className = 'avatar';
avatarDiv.textContent = sender === 'user' ? '你' : 'AI';
const contentDiv = document.createElement('div');
contentDiv.className = 'message-content';
// 如果是AI回复,先做HTML转义再把换行转成<br>,避免把模型输出当HTML执行
if (sender === 'assistant') {
const escaped = content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
contentDiv.innerHTML = escaped.replace(/\n/g, '<br>');
} else {
contentDiv.textContent = content;
}
messageDiv.appendChild(avatarDiv);
messageDiv.appendChild(contentDiv);
chatHistory.appendChild(messageDiv);
// 滚动到底部
scrollToBottom();
}
// 更新性能指标
function updatePerformanceMetrics(data) {
if (data.inference_time) {
responseTimeEl.textContent = `${data.inference_time}秒`;
}
if (data.tokens_per_second) {
generationSpeedEl.textContent = `${data.tokens_per_second} tokens/秒`;
}
if (data.total_tokens) {
totalTokensEl.textContent = data.total_tokens;
}
}
// 滚动到底部
function scrollToBottom() {
chatHistory.scrollTop = chatHistory.scrollHeight;
}
// 清空对话
function clearChat() {
// 保留第一条欢迎消息
const welcomeMessage = chatHistory.querySelector('.message.assistant');
chatHistory.innerHTML = '';
if (welcomeMessage) {
chatHistory.appendChild(welcomeMessage);
}
// 重置性能指标
responseTimeEl.textContent = '-';
generationSpeedEl.textContent = '-';
totalTokensEl.textContent = '-';
}
// 检查服务状态
async function checkServiceStatus() {
try {
const response = await fetch('/api/health');
const data = await response.json();
if (data.status === 'healthy') {
modelStatusEl.className = 'status-indicator status-online';
modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> 在线';
} else {
modelStatusEl.className = 'status-indicator status-offline';
modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> 离线';
}
} catch (error) {
modelStatusEl.className = 'status-indicator status-offline';
modelStatusEl.innerHTML = '<i class="fas fa-circle"></i> 离线';
}
}
// 事件监听
sendBtn.addEventListener('click', sendMessage);
messageInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
}
});
clearBtn.addEventListener('click', clearChat);
// 示例按钮点击事件
document.querySelectorAll('.example-btn').forEach(btn => {
btn.addEventListener('click', () => {
const prompt = btn.getAttribute('data-prompt');
messageInput.value = prompt;
messageInput.focus();
});
});
// 自动调整输入框高度
messageInput.addEventListener('input', function() {
this.style.height = 'auto';
this.style.height = (this.scrollHeight) + 'px';
});
// 定期检查服务状态
checkServiceStatus();
setInterval(checkServiceStatus, 30000); // 每30秒检查一次
// 初始聚焦到输入框
messageInput.focus();
</script>
</body>
</html>
4.2 使用Gradio快速构建界面(备选方案)
如果你想要更快速地搭建界面,也可以使用Gradio。这是一个更简单的方案:
# gradio_app.py
import gradio as gr
import requests
import json
from typing import Dict, Any
# API基础URL
API_BASE = "http://localhost:8080/api"
def chat_with_model(
message: str,
history: list,
temperature: float,
max_length: int
) -> str:
"""
与模型对话的Gradio接口函数
"""
try:
# 准备请求数据
data = {
"prompt": message,
"temperature": temperature,
"max_length": max_length,
"top_p": 0.9
}
# 发送请求
response = requests.post(
f"{API_BASE}/chat",
json=data,
timeout=60
)
if response.status_code == 200:
result = response.json()
if result.get("success"):
return result["response"]
else:
return f"错误: {result.get('error', '未知错误')}"
else:
return f"请求失败: HTTP {response.status_code}"
except requests.exceptions.RequestException as e:
return f"网络错误: {str(e)}"
except Exception as e:
return f"处理错误: {str(e)}"
def get_model_info() -> Dict[str, Any]:
"""获取模型信息"""
try:
response = requests.get(f"{API_BASE}/info", timeout=5)
if response.status_code == 200:
return response.json()
except Exception:
pass
return {}
def create_gradio_interface():
"""创建Gradio界面"""
# 获取模型信息
model_info = get_model_info()
model_name = model_info.get("model", "Youtu-LLM-2B")
capabilities = model_info.get("capabilities", ["对话", "代码生成", "数学推理"])
# 创建界面
with gr.Blocks(
title=f"{model_name} 智能对话",
theme=gr.themes.Soft()
) as demo:
gr.Markdown(f"""
# 🤖 {model_name} 智能对话服务
基于腾讯优图实验室的轻量化大语言模型,支持:{", ".join(capabilities)}
**特点**:
- 🚀 轻量高效:仅需极少显存即可流畅运行
- 💬 智能对话:深度优化的中文对话能力
- 🛠️ 代码辅助:支持多种编程语言
- 🧮 数学推理:擅长逻辑推理和数学计算
**开始对话**:在下方输入你的问题,按Enter发送,Shift+Enter换行
""")
# 聊天历史
chatbot = gr.Chatbot(
label="对话历史",
height=500,
show_copy_button=True
)
# 状态信息
with gr.Row():
with gr.Column(scale=1):
status_text = gr.Textbox(
label="服务状态",
value="🟢 在线",
interactive=False
)
with gr.Column(scale=2):
model_info_text = gr.Textbox(
label="模型信息",
value=f"模型: {model_name} | 上下文长度: 4096 tokens",
interactive=False
)
# 输入区域
with gr.Row():
msg = gr.Textbox(
label="输入消息",
placeholder="输入你的问题...(例如:帮我写一个Python快速排序算法)",
lines=3,
scale=4
)
submit_btn = gr.Button("发送", variant="primary", scale=1)
# 参数控制
with gr.Accordion("高级参数", open=False):
with gr.Row():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="温度 (Temperature)",
info="控制回复的随机性,值越小越确定,值越大越有创意"
)
max_length = gr.Slider(
minimum=100,
maximum=2048,
value=512,
step=100,
label="最大生成长度",
info="控制生成文本的最大长度"
)
# 示例问题
with gr.Accordion("示例问题", open=True):
examples = gr.Examples(
examples=[
["用Python写一个快速排序算法"],
["解释一下量子计算的基本原理"],
["帮我写一封工作汇报邮件"],
["什么是机器学习?用简单的例子说明"],
["写一个关于人工智能的短故事"],
["如何提高学习效率?给出具体建议"]
],
inputs=msg,
label="点击使用示例"
)
# 控制按钮
with gr.Row():
clear_btn = gr.Button("清空对话", variant="secondary")
retry_btn = gr.Button("重试上次", variant="secondary")
# 系统信息
with gr.Accordion("系统信息", open=False):
gr.Markdown(f"""
### 技术规格
- **模型**: {model_name}
- **参数量**: 20亿
- **上下文长度**: 4096 tokens
- **支持语言**: 中文、英文
- **推理后端**: Flask + Transformers
### API接口
- `POST /api/chat` - 聊天接口
- `POST /api/batch_chat` - 批量聊天接口
- `GET /api/health` - 健康检查
- `GET /api/info` - 模型信息
### 使用提示
1. 问题描述越详细,回答越准确
2. 可以要求模型以特定格式回复(如代码、列表、表格等)
3. 温度参数影响创造性,编程任务建议使用较低温度(0.1-0.5)
4. 对话任务建议使用中等温度(0.7-1.0)
""")
# 事件处理
def respond(message, chat_history, temp, max_len):
"""处理用户输入并生成回复"""
if not message.strip():
return "", chat_history
# 添加用户消息到历史
chat_history.append((message, None))
# 生成回复
response = chat_with_model(message, chat_history, temp, max_len)
# 更新最后一条消息的回复
chat_history[-1] = (message, response)
return "", chat_history
# 连接事件
submit_btn.click(
respond,
[msg, chatbot, temperature, max_length],
[msg, chatbot]
)
msg.submit(
respond,
[msg, chatbot, temperature, max_length],
[msg, chatbot]
)
clear_btn.click(lambda: [], None, chatbot)
def retry_last(chat_history, temp, max_len):
"""重试最后一条消息"""
if not chat_history:
return chat_history
# 获取最后一条用户消息,重新生成回复。
# 注意:滑块的当前值必须通过inputs传入,
# 在回调里直接读 component.value 只能拿到初始默认值
last_message = chat_history[-1][0]
response = chat_with_model(
last_message,
chat_history[:-1],
temp,
max_len
)
chat_history[-1] = (last_message, response)
return chat_history
retry_btn.click(retry_last, [chatbot, temperature, max_length], chatbot)
# 定期更新状态
def update_status():
"""更新服务状态"""
try:
response = requests.get(f"{API_BASE}/health", timeout=5)
if response.status_code == 200:
return "🟢 在线"
else:
return "🔴 离线"
except Exception:
return "🔴 离线"
# 状态自动更新
demo.load(update_status, None, status_text, every=30)
return demo
if __name__ == "__main__":
# 创建并启动Gradio界面
demo = create_gradio_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)
5. 部署与优化建议
5.1 生产环境部署
现在我们已经有了完整的服务代码,接下来看看如何部署到生产环境。
使用Gunicorn部署Flask应用:
首先安装Gunicorn:
pip install gunicorn
创建Gunicorn配置文件:
# gunicorn_config.py
# 服务器配置
bind = "0.0.0.0:8080"
# 注意:每个worker进程都会各自加载一份模型。
# 对大模型服务(尤其是GPU推理)建议只开1个worker,用线程承接并发,
# 否则内存/显存会成倍占用
workers = 1
threads = 4
worker_class = "gthread"
timeout = 120
keepalive = 2
# 日志配置
accesslog = "./logs/access.log"
errorlog = "./logs/error.log"
loglevel = "info"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
# 进程配置
daemon = False
pidfile = "./gunicorn.pid"
umask = 0
user = None
group = None
tmp_upload_dir = None
# 服务器钩子
def post_fork(server, worker):
server.log.info("Worker spawned (pid: %s)", worker.pid)
def pre_fork(server, worker):
pass
def pre_exec(server):
server.log.info("Forked child, re-executing.")
def when_ready(server):
server.log.info("Server is ready. Spawning workers")
def worker_int(worker):
worker.log.info("worker received INT or QUIT signal")
def worker_abort(worker):
worker.log.info("worker received SIGABRT signal")
启动服务:
# 使用Gunicorn启动
gunicorn -c gunicorn_config.py app:app
# 或者使用简单命令(同样只开1个worker,避免模型被重复加载)
gunicorn -w 1 --threads 4 -b 0.0.0.0:8080 app:app
使用Docker容器化部署:
创建Dockerfile:
# Dockerfile
FROM python:3.9-slim
# 设置工作目录
WORKDIR /app
# 安装系统依赖(curl 供 docker-compose 的健康检查使用)
RUN apt-get update && apt-get install -y \
gcc \
g++ \
curl \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# 复制应用代码
COPY . .
# 创建日志目录
RUN mkdir -p logs
# 下载模型(可以在构建时下载,或运行时下载)
# 这里假设模型已经预下载到models目录
# 如果没有,可以取消下面的注释
# RUN python download_model.py
# 暴露端口
EXPOSE 8080
# 启动命令(单worker多线程,避免每个worker各加载一份模型)
CMD ["gunicorn", "-w", "1", "--threads", "4", "-b", "0.0.0.0:8080", "app:app"]
创建docker-compose.yml:
# docker-compose.yml
version: '3.8'
services:
youtu-2b-service:
build: .
container_name: youtu-2b-service
ports:
- "8080:8080"
volumes:
- ./models:/app/models
- ./logs:/app/logs
environment:
- PYTHONUNBUFFERED=1
- TZ=Asia/Shanghai
restart: unless-stopped
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/api/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 40s
构建和运行:
# 构建镜像
docker-compose build
# 启动服务
docker-compose up -d
# 查看日志
docker-compose logs -f
# 停止服务
docker-compose down
5.2 性能优化建议
模型加载优化:
# core/model_loader_optimized.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from contextlib import contextmanager
import logging
import os
logger = logging.getLogger(__name__)
class OptimizedModelLoader:
"""优化版的模型加载器"""
def __init__(self, model_path: str = "./models/youtu-2b"):
self.model_path = model_path
self.tokenizer = None
self.model = None
self.device = None
def load_model_with_optimizations(self, use_gpu: bool = True):
"""带优化的模型加载"""
try:
logger.info("开始加载优化版模型...")
# 确定设备
if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
torch.backends.cudnn.benchmark = True # 启用cudnn自动优化
torch.backends.cuda.matmul.allow_tf32 = True # 允许TF32
else:
self.device = torch.device("cpu")
# CPU模式下显式设置推理线程数(原则上不超过物理核数)
torch.set_num_threads(max(1, os.cpu_count() or 1))
# 加载分词器(使用缓存)
logger.info("加载分词器...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True,
cache_dir="./cache" # 指定缓存目录
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# 模型加载配置
load_config = {
"trust_remote_code": True,
"cache_dir": "./cache",
"low_cpu_mem_usage": True, # 减少CPU内存使用
}
# 根据设备选择数据类型和加载方式
if self.device.type == "cuda":
load_config.update({
"torch_dtype": torch.float16, # 使用半精度
"device_map": "auto", # 自动分配到多个GPU
})
else:
load_config.update({
"torch_dtype": torch.float32,
})
# 加载模型
logger.info("加载模型(优化版)...")
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path,
**load_config
)
# 如果使用CPU,启用更多优化
if self.device.type == "cpu":
self.model = self.model.to(self.device)
# 启用CPU推理优化
try:
import intel_extension_for_pytorch as ipex
self.model = ipex.optimize(self.model)
logger.info("已启用Intel PyTorch扩展优化")
except ImportError:
logger.info("未安装Intel PyTorch扩展,使用标准CPU模式")
# 设置为评估模式
self.model.eval()
# 启用更好的推理模式
if hasattr(self.model, "prepare_for_inference"):
self.model.prepare_for_inference()
logger.info("优化版模型加载完成!")
return True, "模型加载成功"
except Exception as e:
error_msg = f"模型加载失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
@contextmanager
def inference_context(self):
"""推理上下文管理器,优化内存使用"""
try:
if self.device.type == "cuda":
# 清空CUDA缓存
torch.cuda.empty_cache()
torch.cuda.synchronize()
yield self.model, self.tokenizer
finally:
if self.device.type == "cuda":
torch.cuda.empty_cache()
def optimize_for_inference(self):
"""进一步优化推理性能"""
if self.model is None:
return
logger.info("应用推理优化...")
# 启用更好的KV缓存(如果支持)
if hasattr(self.model.config, "use_cache"):
self.model.config.use_cache = True
# 注意:梯度检查点(gradient checkpointing)是训练期的省内存手段,
# 推理时启用反而会与KV缓存(use_cache)冲突、拖慢生成,这里不开启
# 如果使用CPU,应用更多优化
if self.device.type == "cpu":
# 设置线程数
torch.set_num_threads(min(4, torch.get_num_threads()))
logger.info("推理优化已应用")
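最后给一个优化版加载器的调用草图(假设目录结构与前文一致),它可以直接替换3.2节的默认加载器:
# 使用示例(草图):用优化版加载器替换默认的 ModelLoader
from core.model_loader_optimized import OptimizedModelLoader
from core.inference import InferenceEngine

loader = OptimizedModelLoader("./models/youtu-2b")
ok, msg = loader.load_model_with_optimizations(use_gpu=True)
assert ok, msg
loader.optimize_for_inference()

# InferenceEngine 只依赖加载器的 tokenizer/model/device 属性,两种加载器可互换
engine = InferenceEngine(loader)
print(engine.generate_response("你好", max_length=64)["response"])
至此,从环境准备、模型封装、Flask API、WebUI 到容器化部署的完整链路就打通了。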