1、推理主程序：run.py

import glob
import os
import shutil
from functools import cmp_to_key
from pathlib import Path
from tempfile import TemporaryDirectory
import random
import jukemirlib
import numpy as np
import torch
from data.slice import slice_audio
from log.EDGE import EDGE
from data.audio_extraction.baseline_features import extract as baseline_extract
from data.audio_extraction.jukebox_features import extract as juke_extract
from data.audio_extraction.jukebox_features import test_jukebox as test_jukeboxs
import os,yaml
from flask import Flask, jsonify
import logging
import logging.handlers
from flask import Flask, render_template, request, session, send_file, make_response, redirect
import warnings
import subprocess
from werkzeug.utils import secure_filename
from pydub import AudioSegment
import time

# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Flask application serving the HTTP API.
app = Flask(__name__)

# Directory that receives uploaded audio; created eagerly when the module is imported.
UPLOAD_PATH = 'out_data/input'
Path(UPLOAD_PATH).mkdir(parents=True, exist_ok=True)

# EDGE inference service: wraps the model plus the pkl -> FBX / BVH conversion steps.
class EDGE_api:
    """Load the EDGE model once and turn an uploaded audio file into motion files.

    ``data`` is a settings dict (the module loads it from the ``test_options``
    section of the YAML config). Besides the keys read in ``__init__``, an
    optional ``bvh_conda_name`` selects the conda environment used for the BVH
    converter; it defaults to the previously hard-coded ``'zrx_edge3'``.
    """

    def __init__(self, data):
        """Read all settings from *data*, build the EDGE model and set up logging."""
        # Sort key for slice files: the integer after "slice" in the last "_"-separated
        # field of the file stem (e.g. "song_slice12.wav" -> 12).
        self.key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
        self.stringintkey = cmp_to_key(self.stringintcmp)
        # juke_extract for "jukebox" features, baseline_extract for anything else.
        self.feature_func = juke_extract if data['feature_type'] == "jukebox" else baseline_extract
        # out_length is presumably in seconds; sample_size converts it to a slice count
        # using the same 2.5 factor that test() passes to slice_audio.
        self.sample_length = data['out_length']
        self.sample_size = int(self.sample_length / 2.5) - 1
        self.model = EDGE(data['feature_type'], data['checkpoint'])
        self.model.eval()
        self.render_dir = data['render_dir']
        self.motion_save_dir = data['motion_save_dir']
        self.cache_features = data['cache_features']
        self.conda_name = data['conda_name']
        self.fbx_save_dir = data['fbx_save_dir']
        self.bvh_save_dir = data['bvh_save_dir']
        self.temp_save_dir = data['temp_save_dir']
        self.no_render = data['no_render']
        self.bvh = data['bvh']
        self.feature_cache_dir = data['feature_cache_dir']
        # Conda env for the smpl2bvh converter; optional so existing configs keep working.
        self.bvh_conda_name = data.get('bvh_conda_name', 'zrx_edge3')
        # Create the rotating log file.
        self.logger = self.logging_save(data['logging_path'], data['when'], data['interval'], data['backupCount'])
        self.logger.info('starting EDGE')
        self.logger.info(data)

    def logging_save(self, save_path, when, interval, backupCount):
        """Return the shared ``'logger'`` logger with a time-rotating file handler.

        Args:
            save_path: log file path.
            when, interval, backupCount: forwarded to ``TimedRotatingFileHandler``.

        The handler for *save_path* is attached only once, so creating several
        instances does not write every log line multiple times.
        """
        logger = logging.getLogger('logger')
        logger.setLevel(logging.DEBUG)

        # FileHandler stores the absolute path in baseFilename; compare against that.
        target = os.path.abspath(save_path)
        for existing in logger.handlers:
            if (isinstance(existing, logging.handlers.TimedRotatingFileHandler)
                    and existing.baseFilename == target):
                return logger

        handler = logging.handlers.TimedRotatingFileHandler(save_path, when=when, interval=interval,
                                                            backupCount=backupCount)
        handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(handler)
        return logger

    def stringintcmp(self, a, b):
        """Compare slice file names: by prefix (fields before the last '_'), then numerically by slice index."""
        aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
        ka, kb = self.key_func(a), self.key_func(b)
        if aa < bb:
            return -1
        if aa > bb:
            return 1
        if ka < kb:
            return -1
        if ka > kb:
            return 1
        return 0

    def _resolve_sample_size(self, all_len, n_files):
        """Return how many consecutive slices to condition on, clamped to ``[1, n_files]``.

        When ``sample_length < 1`` the size is derived from *all_len* (the second
        value returned by ``slice_audio``); otherwise the configured
        ``sample_size`` is used.
        """
        if self.sample_length < 1:
            size = int(all_len / 2.5) - 1
        else:
            size = self.sample_size
        # Clamp so random.randint() below always has a valid range.
        return max(1, min(size, n_files))

    def _run_in_conda(self, env_name, script_args):
        """Run ``python <script_args...>`` inside conda env *env_name*; log a non-zero exit code."""
        cmd = ["conda", "run", "-n", env_name, "python", *script_args]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            self.logger.error('command failed ({}): {}'.format(result.returncode, cmd))
        return result

    def _generate_motion(self, wav_file, dirname, cache_features, render_dir, motion_save_dir, no_render):
        """Slice *wav_file* into *dirname*, extract features and call ``model.render_sample``."""
        _, all_len = slice_audio(wav_file, 2.5, 5.0, dirname)
        file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=self.stringintkey)
        if not file_list:
            raise ValueError('no audio slices produced for {}'.format(wav_file))

        sample_size = self._resolve_sample_size(all_len, len(file_list))
        # Randomly choose a window of sample_size consecutive slices.
        rand_idx = random.randint(0, len(file_list) - sample_size)
        cond_list = []
        for idx, file in enumerate(file_list):
            in_window = rand_idx <= idx < rand_idx + sample_size
            # Without caching only the chosen window needs features.
            if not cache_features and not in_window:
                continue
            reps, _ = self.feature_func(file)
            if cache_features:
                np.save(os.path.splitext(file)[0] + ".npy", reps)
            if in_window:
                cond_list.append(reps)

        cond_list = torch.from_numpy(np.array(cond_list))
        data_tuple = None, cond_list, file_list[rand_idx: rand_idx + sample_size]
        self.model.render_sample(
            data_tuple, "test", render_dir, render_count=-1, fk_out=motion_save_dir, render=not no_render
        )

    def _convert_to_fbx(self, wav_file, outpath, name):
        """Convert the .pkl at *outpath* to FBX in a per-song temp dir and move it into ``fbx_save_dir``."""
        stem = wav_file.split('/')[-1].replace('.wav', '').replace('.mp3', '')
        fbx_save_temp_path = self.temp_save_dir + '/{}/'.format(stem)
        os.makedirs(fbx_save_temp_path, exist_ok=True)
        fbx_name = name.replace('.pkl', '.fbx')
        final_path = os.path.join(self.fbx_save_dir, fbx_name)

        self.logger.info('start - pkl2fbx : {}'.format(outpath))
        try:
            self._run_in_conda(self.conda_name, ["smpl2fbx/Convert.py", '--input_dir', outpath,
                                                 '--output_dir', fbx_save_temp_path])
            # Replace any FBX left over from a previous run with the same name.
            if os.path.exists(final_path):
                os.remove(final_path)
            shutil.move(os.path.join(fbx_save_temp_path, fbx_name), final_path)
        finally:
            # Remove the temp dir even if conversion or the move failed.
            shutil.rmtree(fbx_save_temp_path, ignore_errors=True)
        self.logger.info('end - pkl2fbx : {}'.format(final_path))
        return final_path

    def _convert_to_bvh(self, outpath, name):
        """Convert the .pkl at *outpath* to BVH in ``bvh_save_dir`` and return that path."""
        bvh_path = os.path.join(self.bvh_save_dir, name.replace('.pkl', '.bvh'))
        self.logger.info('start - pkl2bvh : {}'.format(outpath))
        self._run_in_conda(self.bvh_conda_name, ["smpl2bvh/smpl2bvh_edge_new.py", '--poses', outpath,
                                                 '--bvh_file', bvh_path])
        self.logger.info('end - pkl2bvh: {}'.format(bvh_path))
        return bvh_path

    def test(self, parm):
        """Run inference for ``parm['wav_file']`` and return the path of the file to send back.

        Args:
            parm: dict with keys ``cache_features``, ``render_dir``, ``no_render``,
                ``wav_file``, ``feature_cache_dir`` and ``motion_save_dir``. A value of
                ``None`` falls back to the configured default. For the two flags,
                any non-None value enables them; the value itself is not parsed.

        Returns:
            When ``self.bvh`` is set, the ``<bvh_save_dir>/test_<song>_f.bvh`` path
            (presumably written by the BVH converter next to the requested ``.bvh``;
            confirm). Otherwise, the FBX path.
        """
        self.logger.info('start... {}'.format(parm))

        cache_features = self.cache_features if parm['cache_features'] is None else True
        render_dir = self.render_dir if parm['render_dir'] is None else parm['render_dir']
        no_render = self.no_render if parm['no_render'] is None else True
        wav_file = parm['wav_file']
        feature_cache_dir = self.feature_cache_dir if parm['feature_cache_dir'] is None else parm['feature_cache_dir']
        motion_save_dir = self.motion_save_dir if parm['motion_save_dir'] is None else parm['motion_save_dir']

        # Slices go to a persistent per-song cache dir or to a throwaway temp dir.
        temp_dir = None
        if cache_features:
            songname = os.path.splitext(os.path.basename(wav_file))[0]
            dirname = os.path.join(feature_cache_dir, songname)
            Path(dirname).mkdir(parents=True, exist_ok=True)
        else:
            temp_dir = TemporaryDirectory()
            dirname = temp_dir.name

        try:
            self._generate_motion(wav_file, dirname, cache_features, render_dir, motion_save_dir, no_render)
        finally:
            torch.cuda.empty_cache()
            # Only the non-caching path owns a temp dir (the cached path used to crash here).
            if temp_dir is not None:
                temp_dir.cleanup()

        # Assumes render_sample wrote "test_<song>.pkl" into motion_save_dir.
        name = 'test_' + wav_file.split('/')[-1].replace('.wav', '.pkl').replace('.mp3', '.pkl')
        outpath = os.path.join(motion_save_dir, name)
        self.logger.info('end - out path : {}'.format(outpath))

        os.makedirs(self.fbx_save_dir, exist_ok=True)
        os.makedirs(self.bvh_save_dir, exist_ok=True)
        os.makedirs(self.temp_save_dir, exist_ok=True)

        fbx_path = self._convert_to_fbx(wav_file, outpath, name)
        bvh_path = self._convert_to_bvh(outpath, name)

        self.logger.info('result: {}'.format({'output_pkl': outpath,
                                              'output_fbx': fbx_path,
                                              'output_bvh': bvh_path}))

        if self.bvh:
            return os.path.join(self.bvh_save_dir, name.replace('.pkl', '_f.bvh'))
        return fbx_path


def generate_unique_filename(extension):
    """Return a collision-resistant file name ``<ms timestamp>_<uuid4 hex>.<extension>``.

    Several gunicorn workers may accept uploads in the same millisecond. The
    former 4-digit random suffix allowed a 1-in-9000 clash per pair, which
    overwrote another request's upload, so a full uuid4 is used instead.
    """
    import uuid

    timestamp = int(time.time() * 1000)  # millisecond timestamp keeps names roughly sortable
    return f"{timestamp}_{uuid.uuid4().hex}.{extension}"

# Load the fixed service settings from the "test_options" section of the config.
# safe_load is sufficient for a plain YAML config and refuses arbitrary Python tags.
with open("log/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)['test_options']

# Select the GPU(s) for this process. setdefault keeps a value that is already
# exported. When gunicorn runs without --preload, the post_fork hook in gunicorn.py
# sets a per-worker device before this module is imported, and overwriting it
# would put every worker on the same GPU. str() is needed because YAML parses
# `cuda: 0` as an int, and os.environ only accepts strings.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(config['cuda']))
edge = EDGE_api(config)

# jukemirlib.setup_models(device="cpu")
test_jukeboxs(config['features_path'])

def allowed_file(filename):
    """Tell whether *filename* ends in an accepted audio extension (.wav / .mp3, any case)."""
    if "." not in filename:
        return False
    _, _, extension = filename.rpartition(".")
    return extension.lower() in ("wav", "mp3")


@app.route('/upload/', methods=['GET', 'POST'])
def upload():
    """Serve the upload form (GET) or run EDGE on an uploaded .wav/.mp3 file (POST).

    On POST the multipart field ``pic`` carries the audio. Query-string parameters
    ``cache_features``, ``render_dir``, ``no_render``, ``feature_cache_dir`` and
    ``motion_save_dir`` override the configured defaults for this request. The
    motion file produced by ``edge.test`` is returned as an attachment. A missing,
    empty or disallowed upload redirects back to the same URL.
    """
    if request.method == 'GET':
        return render_template('upload.html')

    file = request.files.get('pic')
    # No file part in the request, or the user submitted without choosing a file.
    if file is None or not file.filename:
        return redirect(request.url)
    if not allowed_file(file.filename):
        return redirect(request.url)

    # secure_filename() strips non-ASCII (e.g. Chinese) names, so a generated name is used.
    # Lower-casing the extension makes ".WAV" take the wav path instead of from_mp3().
    extension = file.filename.rsplit('.', 1)[1].lower()
    outpath = os.path.join(UPLOAD_PATH, generate_unique_filename(extension))
    file.save(outpath)

    if extension != 'wav':
        # allowed_file() only admits wav and mp3, so anything else here is mp3.
        sound = AudioSegment.from_mp3(outpath)
        outpath = os.path.splitext(outpath)[0] + '.wav'
        sound.export(outpath, format="wav")

    # NOTE(review): the *_dir overrides come straight from the client and are used as
    # filesystem paths by edge.test(); restrict them if this service is publicly exposed.
    keys = ['cache_features', 'render_dir', 'no_render', 'wav_file', 'feature_cache_dir', 'motion_save_dir']
    parm = {key: request.args.get(key) for key in keys}
    parm['file_name'] = file.filename
    parm['wav_file'] = outpath
    return send_file(edge.test(parm), as_attachment=True)


if __name__ == "__main__":
    # Only runs when the file is executed directly; gunicorn imports `app` instead.
    host, port = config['host'], config['port']
    app.run(host=host, port=port, debug=False)

2、gunicorn 配置文件：gunicorn.py

import os
import time
# CUDA device id per worker age key ("1".."9", stored as strings); every key maps
# to the next id up, so "1" -> "2", ..., "9" -> "10".
gpu_assignments = {str(slot): str(slot + 1) for slot in range(1, 10)}

# Number of GPUs reported by NVML; pre_fork uses it to spread workers over devices.
# Falls back to 1 when pynvml is missing or NVML cannot be initialised.
try:
    import pynvml
    pynvml.nvmlInit()
    # Never 0: pre_fork takes `% gpuDeviceCount`.
    gpuDeviceCount = max(1, pynvml.nvmlDeviceGetCount())
except Exception:
    # Exception rather than a bare except, so Ctrl-C / SystemExit during startup still propagate.
    gpuDeviceCount = 1

# GPU ids released by exited workers (see child_exit), reused by the next pre_fork.
gpuDevicePool = []

def pre_fork(server, worker):
    """Gunicorn hook, run in the master just before forking: choose a GPU id for *worker*.

    An id released by an exited worker is reused first. Otherwise workers are
    spread round-robin over ``gpuDeviceCount`` devices by age. The choice is
    stored on ``worker.gid``.
    """
    # Explicit emptiness check instead of a bare except, which hid unrelated errors.
    if gpuDevicePool:
        gid = gpuDevicePool.pop(0)
    else:
        gid = (worker.age - 1) % gpuDeviceCount
    worker.gid = gid


def post_fork(server, worker):
    """Gunicorn hook, run in the freshly forked worker: pin it to one GPU.

    The device comes from ``gpu_assignments``, looked up by worker age.
    ``worker.age`` is an int while the table is keyed by strings, and gunicorn
    keeps incrementing the age for replacement workers. The age is therefore
    converted to a string and wrapped around the table instead of raising
    KeyError.
    """
    # Presumably staggers start-up so the workers do not all load the model at once.
    time.sleep(worker.age % server.cfg.workers)
    slot = str(worker.age)
    if slot not in gpu_assignments:
        slots = sorted(gpu_assignments, key=int)
        slot = slots[(worker.age - 1) % len(slots)]
    worker.gid = gpu_assignments[slot]
    os.environ['CUDA_VISIBLE_DEVICES'] = str(worker.gid)
    server.log.info(f'worker(age:{worker.age}, pid:{worker.pid}, cuda:{worker.gid})')

def child_exit(server, worker):
    """Gunicorn hook, run in the master after a worker exits: hand its GPU id back.

    The id is appended to ``gpuDevicePool`` so the next pre_fork can reuse it.
    """
    released_gid = worker.gid
    gpuDevicePool.append(released_gid)

3、启动服务，开启 6 个 worker 进程（命令中 `run_bvh:app` 的模块名须与第 1 步主程序的文件名一致，例如文件为 run.py 时应写 `run:app`）：

gunicorn --config gunicorn.py \
    --workers 6 \
    --timeout 500 \
    --bind 0.0.0.0:8080 \
    run_bvh:app
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐