gunicorn+flask+EDGE部署(多进程、多gpu)
1、推理主函数:run.py。2、gunicorn 参数文件 gunicorn.py。3、执行代码,开启 6 个进程。
·
1、推理主函数:run.py
import glob
import os
import shutil
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory
import random
import jukemirlib
import numpy as np
import torch
from data.slice import slice_audio
from log.EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract
from data.audio_extraction.jukebox_features import test_jukebox as test_jukeboxs
import os,yaml
from flask import Flask, jsonify
import logging
import logging.handlers
from flask import Flask, render_template, request, session, send_file, make_response, redirect
import warnings
import subprocess
from werkzeug.utils import secure_filename
from pydub import AudioSegment
import time
# Suppress every warning for the whole process.
warnings.filterwarnings("ignore")

# Flask application object; gunicorn serves it through the "<module>:app" target.
app = Flask(__name__)

# Uploaded audio is written here before inference; make sure the folder exists.
UPLOAD_PATH = 'out_data/input'
Path(UPLOAD_PATH).mkdir(parents=True, exist_ok=True)
# EDGE inference wrapper
class EDGE_api:
    """Load the EDGE model once and serve per-request dance generation.

    ``data`` is the ``test_options`` mapping read from ``log/config.yaml``.
    Each ``test`` call turns one audio file into a motion ``.pkl`` and then
    converts it to FBX and BVH with scripts run through ``conda run``.
    """

    def __init__(self, data):
        # Read all settings from the config mapping.
        # Slice index from a slice file name, e.g. "<song>_slice12.wav" -> 12.
        self.key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
        self.stringintkey = cmp_to_key(self.stringintcmp)
        self.feature_func = juke_extract if data['feature_type'] == "jukebox" else baseline_extract
        self.sample_length = data['out_length']
        # NOTE(review): computed but never read in this class; test() uses
        # sample_length directly as the slice count — confirm which is intended.
        self.sample_size = int(self.sample_length / 2.5) - 1
        self.model = EDGE(data['feature_type'], data['checkpoint'])
        self.model.eval()
        self.render_dir = data['render_dir']
        self.motion_save_dir = data['motion_save_dir']
        self.cache_features = data['cache_features']
        self.conda_name = data['conda_name']
        self.fbx_save_dir = data['fbx_save_dir']
        self.bvh_save_dir = data['bvh_save_dir']
        self.temp_save_dir = data['temp_save_dir']
        self.no_render = data['no_render']
        self.bvh = data['bvh']
        self.feature_cache_dir = data['feature_cache_dir']
        # Create the rotating log file.
        self.logger = self.logging_save(data['logging_path'], data['when'], data['interval'], data['backupCount'])
        self.logger.info('starting EDGE')
        self.logger.info(data)

    def logging_save(self, save_path, when, interval, backupCount):
        """Return the shared ``'logger'`` logger writing to a rotating file.

        Args:
            save_path: log file path.
            when, interval, backupCount: passed straight to
                ``TimedRotatingFileHandler``.

        The logger is process-global, so a handler for ``save_path`` is added
        only once; calling this again does not duplicate every log line.
        """
        logger = logging.getLogger('logger')
        logger.setLevel(logging.DEBUG)
        target = os.path.abspath(save_path)
        for existing in logger.handlers:
            if (isinstance(existing, logging.handlers.TimedRotatingFileHandler)
                    and existing.baseFilename == target):
                return logger
        handler = logging.handlers.TimedRotatingFileHandler(save_path, when=when, interval=interval,
                                                            backupCount=backupCount)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger

    def stringintcmp(self, a, b):
        """cmp-style ordering of slice files: by song prefix, then by numeric slice index."""
        aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
        ka, kb = self.key_func(a), self.key_func(b)
        if aa < bb:
            return -1
        if aa > bb:
            return 1
        if ka < kb:
            return -1
        if ka > kb:
            return 1
        return 0

    def test(self, parm):
        """Run one inference request and return the path of the file to send back.

        Args:
            parm: request options. ``wav_file`` (path of the uploaded WAV) is
                required. ``render_dir``, ``feature_cache_dir`` and
                ``motion_save_dir`` override the config values when not None;
                ``cache_features`` and ``no_render`` are switched on by any
                non-None value.

        Returns:
            The BVH path when ``self.bvh`` is truthy, otherwise the FBX path.

        Raises:
            ValueError: if slicing the audio produced no ``.wav`` slices.
        """
        self.logger.info('start... {}'.format(parm))
        # NOTE(review): any non-None value (even the string "false") enables these two flags.
        cache_features = self.cache_features if parm['cache_features'] is None else True
        render_dir = self.render_dir if parm['render_dir'] is None else parm['render_dir']
        no_render = self.no_render if parm['no_render'] is None else True
        wav_file = parm['wav_file']
        feature_cache_dir = self.feature_cache_dir if parm['feature_cache_dir'] is None else parm['feature_cache_dir']
        motion_save_dir = self.motion_save_dir if parm['motion_save_dir'] is None else parm['motion_save_dir']

        # Slices go either to a persistent per-song cache dir or to a throwaway temp dir.
        temp_dir = None
        if cache_features:
            songname = os.path.splitext(os.path.basename(wav_file))[0]
            dirname = os.path.join(feature_cache_dir, songname)
            Path(dirname).mkdir(parents=True, exist_ok=True)
        else:
            temp_dir = TemporaryDirectory()
            dirname = temp_dir.name
        try:
            cond_list, window_files = self._build_conditioning(wav_file, dirname, cache_features)
            data_tuple = None, cond_list, window_files
            self.model.render_sample(
                data_tuple, "test", render_dir, render_count=-1, fk_out=motion_save_dir, render=not no_render
            )
        finally:
            torch.cuda.empty_cache()
            # Only the temp dir is cleaned up (the cache dir is kept on purpose);
            # temp_dir is unbound-safe here, and cleanup also runs if inference fails.
            if temp_dir is not None:
                temp_dir.cleanup()

        name = 'test_' + wav_file.split('/')[-1].replace('.wav', '.pkl').replace('.mp3', '.pkl')
        outpath = os.path.join(motion_save_dir, name)
        self.logger.info('end - out path : {}'.format(outpath))
        os.makedirs(self.fbx_save_dir, exist_ok=True)
        os.makedirs(self.bvh_save_dir, exist_ok=True)
        os.makedirs(self.temp_save_dir, exist_ok=True)
        fbx_path = self._convert_to_fbx(wav_file, outpath, name)
        bvh_path = self._convert_to_bvh(outpath, name)
        self.logger.info('result: {}'.format({'output_pkl': outpath,
                                              'output_fbx': fbx_path,
                                              'output_bvh': bvh_path}))
        if self.bvh:
            # NOTE(review): returns "<name>_f.bvh" although the converter was given
            # "<name>.bvh"; presumably the script also writes this "_f" file — confirm.
            return self.bvh_save_dir + '/' + name.replace('.pkl', '_f.bvh')
        return fbx_path

    def _build_conditioning(self, wav_file, dirname, cache_features):
        """Slice *wav_file* into *dirname*, extract features, and pick a random window.

        Returns:
            ``(cond_list, window_files)``: a tensor built from the features of
            the chosen window, and the slice paths in that window.
        """
        _, all_len = slice_audio(wav_file, 2.5, 5.0, dirname)
        file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=self.stringintkey)
        if not file_list:
            raise ValueError('no audio slices were produced for {}'.format(wav_file))
        if self.sample_length < 1:
            sample_size = int(all_len / 2.5) - 1
        else:
            sample_size = self.sample_length
        # Randomly sample a chunk of at most sample_size slices. The upper bound is
        # clamped so a song shorter than the window uses all its slices instead of
        # making randint raise on an empty range.
        rand_idx = random.randint(0, max(0, len(file_list) - sample_size))
        cond_list = []
        for idx, file in enumerate(file_list):
            in_window = rand_idx <= idx < rand_idx + sample_size
            # Without caching only the chosen window needs features.
            if not cache_features and not in_window:
                continue
            reps, _ = self.feature_func(file)
            if cache_features:
                # Save features next to the slice for reuse.
                np.save(os.path.splitext(file)[0] + ".npy", reps)
            if in_window:
                cond_list.append(reps)
        return torch.from_numpy(np.array(cond_list)), file_list[rand_idx : rand_idx + sample_size]

    def _convert_to_fbx(self, wav_file, outpath, name):
        """Convert the motion pkl *outpath* to FBX via ``smpl2fbx/Convert.py``; return the FBX path."""
        fbx_name = name.replace('.pkl', '.fbx')
        fbx_path = self.fbx_save_dir + '/' + fbx_name
        # Per-song scratch dir the converter writes into before the file is moved into place.
        fbx_save_temp_path = self.temp_save_dir + '/{}/'.format(wav_file.split('/')[-1].replace('.wav', '').replace('.mp3', ''))
        os.makedirs(fbx_save_temp_path, exist_ok=True)
        self.logger.info('start - pkl2fbx : {}'.format(outpath))
        # The converter runs inside the conda env named in the config.
        result = subprocess.run(["conda", "run", "-n", self.conda_name, "python", "smpl2fbx/Convert.py", '--input_dir', outpath, '--output_dir', fbx_save_temp_path])
        if result.returncode != 0:
            self.logger.error('pkl2fbx exited with code {}'.format(result.returncode))
        if os.path.exists(fbx_path):
            os.remove(fbx_path)
        shutil.move(fbx_save_temp_path + '/' + fbx_name, fbx_path)
        shutil.rmtree(fbx_save_temp_path)
        self.logger.info('end - pkl2fbx : {}'.format(fbx_path))
        return fbx_path

    def _convert_to_bvh(self, outpath, name):
        """Convert the motion pkl *outpath* to BVH via ``smpl2bvh/smpl2bvh_edge_new.py``; return the BVH path."""
        bvh_path = self.bvh_save_dir + '/' + name.replace('.pkl', '.bvh')
        self.logger.info('start - pkl2bvh : {}'.format(outpath))
        # NOTE(review): the conda env is hard-coded here, unlike the FBX step which
        # uses self.conda_name — confirm this is intentional.
        result = subprocess.run(["conda", "run", "-n", 'zrx_edge3', "python", "smpl2bvh/smpl2bvh_edge_new.py", '--poses', outpath, '--bvh_file', bvh_path])
        if result.returncode != 0:
            self.logger.error('pkl2bvh exited with code {}'.format(result.returncode))
        self.logger.info('end - pkl2bvh: {}'.format(bvh_path))
        return bvh_path
def generate_unique_filename(extension):
    """Return a fresh upload file name ``<ms-timestamp>_<8 hex chars>.<extension>``.

    The random part comes from ``os.urandom`` instead of the shared ``random``
    module. A 4-digit ``random.randint`` suffix collides easily when several
    gunicorn workers receive uploads in the same millisecond. Forked workers
    can also share the same ``random`` state.
    """
    timestamp = int(time.time() * 1000)  # millisecond timestamp
    random_part = os.urandom(4).hex()  # 8 hex characters from the OS random source
    return f"{timestamp}_{random_part}.{extension}"
# Load the fixed test options from the YAML config.
with open("log/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)['test_options']
# Keep a CUDA_VISIBLE_DEVICES that was already set. Under gunicorn (without
# --preload) the post_fork hook in gunicorn.py assigns one GPU per worker before
# this module is imported; overwriting it unconditionally put every worker on the
# same GPU. Outside gunicorn the config value is used. str() because os.environ
# only accepts strings and YAML may parse the value as an int.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(config['cuda']))
edge = EDGE_api(config)
# jukemirlib.setup_models(device="cpu")
test_jukeboxs(config['features_path'])
# Accept only .wav / .mp3 uploads (extension compared case-insensitively).
def allowed_file(filename):
    _stem, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in ('wav', 'mp3')
@app.route('/upload/', methods=['GET', 'POST'])
def upload():
    """Serve the upload form (GET) or run inference on an uploaded audio file (POST).

    The file comes from form field ``pic``. MP3 uploads are converted to WAV
    first. Optional overrides for ``EDGE_api.test`` are read from the query
    string. On success the generated FBX/BVH file is sent as an attachment;
    invalid uploads redirect back to the same URL.
    """
    if request.method == 'GET':
        return render_template('upload.html')
    file = request.files.get('pic')
    # No 'pic' field at all, or the form was submitted without choosing a file.
    if file is None or file.filename == '':
        return redirect(request.url)
    if not (file and allowed_file(file.filename)):
        return redirect(request.url)
    # secure_filename() is not used because it drops non-ASCII (e.g. Chinese) names;
    # a generated unique name is used instead. The extension is lower-cased so
    # "SONG.WAV" is not mistaken for an MP3 below.
    extension = file.filename.rsplit('.', 1)[-1].lower()
    filename = generate_unique_filename(extension)
    outpath = os.path.join(UPLOAD_PATH, filename)
    file.save(outpath)
    if extension != 'wav':
        sound = AudioSegment.from_mp3(outpath)
        outpath = os.path.splitext(outpath)[0] + '.wav'
        sound.export(outpath, format="wav")
    # NOTE(security): render_dir / feature_cache_dir / motion_save_dir come from the
    # query string, so any client can choose where the server writes files.
    # Restrict or drop these overrides if the service is exposed to untrusted users.
    keys = ['cache_features', 'render_dir', 'no_render', 'wav_file', 'feature_cache_dir', 'motion_save_dir']
    parm = {key: request.args.get(key) for key in keys}
    parm['file_name'] = file.filename
    parm['wav_file'] = outpath
    return send_file(edge.test(parm), as_attachment=True)
if __name__ == "__main__":
    # Direct `python` entry point; the gunicorn command imports `app` instead.
    host, port = config['host'], config['port']
    app.run(host=host, port=port, debug=False)
2、gunicorn 参数文件 gunicorn.py
import os
import time
# Physical GPU for the worker spawned with a given gunicorn age.
# Keys are strings, so look them up with str(worker.age).
gpu_assignments = {
    '1': '2',
    '2': '3',
    '3': '4',
    '4': '5',
    '5': '6',
    '6': '7',
    '7': '8',
    '8': '9',
    '9': '10',
}

try:
    import pynvml
    pynvml.nvmlInit()
    gpuDeviceCount = pynvml.nvmlDeviceGetCount()
except Exception:
    # pynvml missing or NVML unavailable: fall back to a single device.
    gpuDeviceCount = 1

# GPUs released by exited workers (appended in child_exit, reused in pre_fork).
gpuDevicePool = []


def pre_fork(server, worker):
    """Choose the GPU for *worker* before it is forked.

    A GPU freed by an exited worker is reused first, so a respawned worker takes
    over its predecessor's device. Otherwise gpu_assignments is used, and ages
    without an entry fall back to round-robin over gpuDeviceCount.
    """
    if gpuDevicePool:
        gid = gpuDevicePool.pop(0)
    else:
        gid = gpu_assignments.get(str(worker.age), str((worker.age - 1) % gpuDeviceCount))
    worker.gid = gid


def post_fork(server, worker):
    """Expose only the GPU chosen in pre_fork to the new worker process."""
    # Stagger start-up so workers do not all load the model at the same moment.
    time.sleep(worker.age % server.cfg.workers)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(worker.gid)
    server.log.info(f'worker(age:{worker.age}, pid:{worker.pid}, cuda:{worker.gid})')


def child_exit(server, worker):
    """Return the exited worker's GPU to the pool for the next spawned worker."""
    gpuDevicePool.append(worker.gid)
3、执行代码,开启 6 个进程:
# Launch 6 gunicorn workers using gunicorn.py as the config file (its hooks set
# CUDA_VISIBLE_DEVICES per worker), listening on port 8080.
# NOTE: "run_bvh:app" means the Flask object `app` in module run_bvh.py; the
# inference script above is called run.py, so rename one of them to match.
gunicorn -c gunicorn.py run_bvh:app -w 6 --timeout 500 --bind 0.0.0.0:8080
更多推荐
已为社区贡献2条内容
所有评论(0)