File size: 5,090 Bytes
4de73fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
from text import text_to_sequence
import numpy as np
from scipy.io import wavfile
import torch
import json
import commons
import utils
import sys
import pathlib
from flask import Flask, request
import threading
import onnxruntime as ort
import time
from pydub import AudioSegment
import io
import os
from transformers import AutoTokenizer, AutoModel
import tkinter as tk
from tkinter import scrolledtext
from scipy.io.wavfile import write
def get_args():
    """Parse command-line options for the inference script.

    Returns:
        argparse.Namespace with: onnx_model (path to exported ONNX VITS
        model), cfg (model config JSON containing "symbols"), outdir,
        audio (WAV path whose 'temp' siblings get overwritten), and
        ChatGLM (local path/repo of the ChatGLM-6B model).
    """
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', default='./moe/model.onnx')
    parser.add_argument('--cfg', default='./moe/config_v.json')
    # Typo fixed: was "ouput folder".
    parser.add_argument('--outdir', default='./moe',
                        help='output folder')
    parser.add_argument('--audio',
                        type=str,
                        help='你要替换的音频文件的,假设这些音频文件为temp1、temp2、temp3......',
                        default='D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
    parser.add_argument('--ChatGLM', default='./moe',
                        help='https://github.com/THUDM/ChatGLM-6B')
    return parser.parse_args()

def to_numpy(tensor: torch.Tensor):
    """Detach *tensor* from autograd and return it as a NumPy array.

    Tensors that carry gradients are moved to CPU first; others are
    converted directly (matching the original behavior exactly).
    """
    detached = tensor.detach()
    if tensor.requires_grad:
        return detached.cpu().numpy()
    return detached.numpy()

def get_symbols_from_json(path):
    """Load the vocabulary symbol list from a model config JSON file.

    Args:
        path: path to a JSON file with a top-level "symbols" array.

    Returns:
        The list stored under the "symbols" key.

    Raises:
        FileNotFoundError: if *path* does not exist (from open()).
        KeyError: if the config has no "symbols" entry.
    """
    # open() already raises FileNotFoundError for a missing file, so the
    # old `assert os.path.isfile(path)` (silently stripped under -O) and
    # the redundant function-local `import os` are both removed.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']

# --- Module-level initialization (runs at import time) ---
args = get_args()
# Symbol table from the model config; a symbol's index is its phone id.
symbols = get_symbols_from_json(args.cfg)
# NOTE(review): phone_dict is never read in this file — dead code? confirm.
phone_dict = {
        symbol: i for i, symbol in enumerate(symbols)
    }
# Hyperparameters (sampling_rate, text_cleaners, add_blank, ...) from the same config.
hps = utils.get_hparams_from_file(args.cfg)
# ONNX Runtime session reused by infer() for every synthesis call.
ort_sess = ort.InferenceSession(args.onnx_model)

def is_japanese(string):
    """Return True if *string* contains a kana code point (U+3041–U+30FE).

    Uses the same strictly-exclusive range test as the original
    (0x3040 and 0x30FF themselves do not count as Japanese).
    """
    return any(0x3040 < ord(ch) < 0x30FF for ch in string)

def infer(text):
    """Synthesize *text* with the ONNX VITS model and write WAV files to disk.

    Writes the raw synthesis to `<args.audio>.wav`, then shells out to
    ffmpeg 19 times to resample it to 44.1 kHz into the temp1..temp19
    sibling files of args.audio.  Returns the language-tagged text.
    """
    # Speaker id baked into this script (was: 选择你想要的角色 — "pick the
    # character you want").
    sid = 7
    # Wrap in [JA]/[ZH] tags so the text cleaner picks the right language.
    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
    #seq = text_to_sequence(text, symbols=hps.symbols, cleaner_names=hps.data.text_cleaners)
    seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
    if hps.data.add_blank:
        seq = commons.intersperse(seq, 0)
    with torch.no_grad():
        x = np.array([seq], dtype=np.int64)
        x_len = np.array([x.shape[1]], dtype=np.int64)
        sid = np.array([sid], dtype=np.int64)
        # presumably (noise, noise_w, length) scales expected by the
        # exported model — TODO confirm against the ONNX export script.
        scales = np.array([0.667, 0.7, 1], dtype=np.float32)
        scales.resize(1, 3)
        ort_inputs = {
                    'input': x,
                    'input_lengths': x_len,
                    'scales': scales,
                    'sid': sid
                }
        t1 = time.time()
        audio = np.squeeze(ort_sess.run(None, ort_inputs))
        # Normalize to ~60% of int16 full scale; max(0.01, ...) guards
        # against division by zero on (near-)silent output.
        audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
        audio = np.clip(audio, -32767.0, 32767.0)
        # NOTE(review): bytes_wav/byte_io are created but never used — dead code?
        bytes_wav = bytes()
        byte_io = io.BytesIO(bytes_wav)
        wavfile.write(args.audio + '.wav',hps.data.sampling_rate, audio.astype(np.int16))
        i = 0
        while i < 19:
            i +=1
            # Resample to 44.1 kHz and fan out to the tempN siblings via ffmpeg.
            cmd = 'ffmpeg -y -i ' +  args.audio + '.wav' + ' -ar 44100 '+ args.audio.replace('temp','temp'+str(i))
            os.system(cmd)
        t2 = time.time()
        print("推理耗时:",(t2 - t1),"s")
    return text
# ChatGLM-6B tokenizer and model, loaded from the local path given by --ChatGLM.
tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
# 8G GPU: fp16 + 4-bit quantization so the model fits in ~8 GB of VRAM.
model = AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True).half().quantize(4).cuda()
# Conversation history passed to model.chat() on every turn.
history = []
def send_message():
    """Send-button callback: chat with ChatGLM, speak the reply, update the UI.

    Reads the user's text from input_box.  The literal message 'clear'
    resets the chat history; anything else is answered by the model,
    synthesized to audio via infer(), and both sides of the exchange are
    appended to chat_box.
    """
    global history
    message = input_box.get("1.0", "end-1c") # fetch the user's input text
    t1 = time.time()
    if message == 'clear':
      history = []
    else:
      # NOTE(review): new_history is discarded and `history` is never
      # updated, so multi-turn context never accumulates — confirm intent.
      response, new_history = model.chat(tokenizer, message, history)
      # Strip spaces and turn newlines into '.' so TTS gets one flat sentence.
      response = response.replace(" ",'').replace("\n",'.')
      text = infer(response)
      # Drop the [JA]/[ZH] tags infer() added before displaying the reply.
      text = text.replace('[JA]','').replace('[ZH]','')
      chat_box.configure(state='normal') # make the chat box writable
      chat_box.insert(tk.END, "You: " + message + "\n") # show the user's message
      chat_box.insert(tk.END, "Tamao: " + text + "\n") # show the bot's reply
    chat_box.configure(state='disabled') # back to read-only
    input_box.delete("1.0", tk.END) # clear the input box
    t2 = time.time()
    print("总共耗时:",(t2 - t1),"s")

# --- Tkinter UI ---
root = tk.Tk()
root.title("Tamao")

# Chat transcript area (kept read-only except while inserting).
chat_box = scrolledtext.ScrolledText(root, width=50, height=10)
chat_box.configure(state='disabled') # starts read-only
chat_box.pack(side=tk.TOP, fill=tk.BOTH, padx=10, pady=10, expand=True)

# Input row: text entry plus Send button.
input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)
input_box = tk.Text(input_frame, height=3, width=50) # input box 50 chars wide
input_box.pack(side=tk.LEFT, fill=tk.X, padx=10, expand=True)
send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)

# Enter the Tk event loop (blocks until the window is closed).
root.mainloop()