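# Chat-with-voice demo: a Tkinter chat window whose replies come from a locally
# loaded ChatGLM-6B model and are spoken by an ONNX-exported VITS model; the
# synthesized audio is resampled to 44.1 kHz and written over the temp1..temp19
# wav files referenced by the --audio path (a Live2D app's sounds folder by default).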
import argparse
import json
import os
import time
import tkinter as tk
from tkinter import scrolledtext

import numpy as np
import onnxruntime as ort
import torch
from scipy.io import wavfile
from transformers import AutoTokenizer, AutoModel

import commons
import utils
from text import text_to_sequence
def get_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', default='./moe/model.onnx')
    parser.add_argument('--cfg', default="./moe/config_v.json")
    parser.add_argument('--outdir', default="./moe",
                        help='output folder')
    parser.add_argument('--audio',
                        type=str,
                        help='The audio file to replace; the generated copies are assumed to be named temp1, temp2, temp3, ...',
                        default='D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
    parser.add_argument('--ChatGLM', default="./moe",
                        help='https://github.com/THUDM/ChatGLM-6B')
    args = parser.parse_args()
    return args
def to_numpy(tensor: torch.Tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad \
else tensor.detach().numpy()
def get_symbols_from_json(path):
    assert os.path.isfile(path)
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']
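# ---- Module-level setup: parse CLI arguments, read the symbol table and
# hyperparameters from the VITS config, and create the ONNX Runtime session ----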
args = get_args()
symbols = get_symbols_from_json(args.cfg)
phone_dict = {
symbol: i for i, symbol in enumerate(symbols)
}
hps = utils.get_hparams_from_file(args.cfg)
ort_sess = ort.InferenceSession(args.onnx_model)
def is_japanese(string):
    # Treat any hiragana/katakana code point (U+3040-U+30FF) as Japanese text
    for ch in string:
        if 0x3040 < ord(ch) < 0x30FF:
            return True
    return False
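# TTS pipeline: wrap the text in [JA]/[ZH] language tags, convert it to a
# phoneme-ID sequence, run the exported VITS model via ONNX Runtime, write a
# 16-bit WAV, then use ffmpeg to emit 44.1 kHz copies named temp1..temp19.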
def infer(text):
    # Speaker ID: pick the character/voice you want
    sid = 7
    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
    #seq = text_to_sequence(text, symbols=hps.symbols, cleaner_names=hps.data.text_cleaners)
    seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
    if hps.data.add_blank:
        seq = commons.intersperse(seq, 0)
    x = np.array([seq], dtype=np.int64)
    x_len = np.array([x.shape[1]], dtype=np.int64)
    sid = np.array([sid], dtype=np.int64)
    # Sampling parameters for the exported VITS model
    # (likely [noise_scale, noise_scale_w, length_scale])
    scales = np.array([0.667, 0.7, 1], dtype=np.float32)
    scales.resize(1, 3)
    ort_inputs = {
        'input': x,
        'input_lengths': x_len,
        'scales': scales,
        'sid': sid
    }
    t1 = time.time()
    audio = np.squeeze(ort_sess.run(None, ort_inputs))
    # Scale to 16-bit PCM range (with ~0.6 headroom) and clip
    audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
    audio = np.clip(audio, -32767.0, 32767.0)
    # Intermediate file (e.g. temp.wav.wav) at the model's native sampling rate
    wavfile.write(args.audio + '.wav', hps.data.sampling_rate, audio.astype(np.int16))
    # Resample to 44.1 kHz and write the result over temp1 ... temp19
    for i in range(1, 20):
        cmd = 'ffmpeg -y -i ' + args.audio + '.wav' + ' -ar 44100 ' + args.audio.replace('temp', 'temp' + str(i))
        os.system(cmd)
    t2 = time.time()
    print("Inference time:", (t2 - t1), "s")
    return text
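# Load the ChatGLM-6B tokenizer and model from the local --ChatGLM directory;
# `history` holds the running conversation passed to model.chat().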
tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
# INT4 quantization so the model fits on an 8 GB GPU
model = AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True).half().quantize(4).cuda()
history = []
def send_message():
    global history
    message = input_box.get("1.0", "end-1c")  # Text typed by the user
    t1 = time.time()
    if message == 'clear':
        history = []
    else:
        response, new_history = model.chat(tokenizer, message, history)
        history = new_history  # Keep the conversation context for the next turn
        response = response.replace(" ", '').replace("\n", '.')
        text = infer(response)
        text = text.replace('[JA]', '').replace('[ZH]', '')
        chat_box.configure(state='normal')  # Make the chat box writable
        chat_box.insert(tk.END, "You: " + message + "\n")  # Show the user's message
        chat_box.insert(tk.END, "Tamao: " + text + "\n")  # Show the chatbot's reply
        chat_box.configure(state='disabled')  # Back to read-only
        input_box.delete("1.0", tk.END)  # Clear the input box
    t2 = time.time()
    print("Total time:", (t2 - t1), "s")
root = tk.Tk()
root.title("Tamao")
# Create the chat display box
chat_box = scrolledtext.ScrolledText(root, width=50, height=10)
chat_box.configure(state='disabled')  # The chat box starts read-only
chat_box.pack(side=tk.TOP, fill=tk.BOTH, padx=10, pady=10, expand=True)
# Create the input box and the Send button
input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)
input_box = tk.Text(input_frame, height=3, width=50)  # Input box width of 50 characters
input_box.pack(side=tk.LEFT, fill=tk.X, padx=10, expand=True)
send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)
# Start the Tk main loop
root.mainloop()
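
# Example invocation (the script filename "inference_gui.py" is illustrative;
# the argument values shown are the defaults declared in get_args above):
#   python inference_gui.py --onnx_model ./moe/model.onnx --cfg ./moe/config_v.json --ChatGLM ./moe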