import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""


import pickle
import torch
from collections import OrderedDict
from AutoPST.onmt_modules.misc import sequence_mask
from AutoPST.model_autopst import Generator_2 as Predictor
from AutoPST.hparams_autopst import hparams

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

P = Predictor(hparams).eval().to(device)

checkpoint = torch.load(
    hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='580000-P.ckpt'),
    map_location=lambda storage, loc: storage,
)
P.load_state_dict(checkpoint['model'], strict=True)
print('Loaded predictor .....................................................')

# test_vctk.meta maps speaker id -> utterance id -> (cepstral features, speaker embedding)
with open('./AutoPST/assets/test_vctk.meta', 'rb') as f:
    dict_test = pickle.load(f)

spect_vc = OrderedDict()

# (source speaker, target speaker, utterance id) triples: each source
# utterance is converted to the target speaker's voice.
uttrs = [('p231', 'p270', '001'),
         ('p270', 'p231', '001'),
         ('p231', 'p245', '003001'),
         ('p245', 'p231', '003001'),
         ('p239', 'p270', '024002'),
         ('p270', 'p239', '024002')]


for uttr in uttrs:

    # source utterance: batch the cepstral frames and build a length mask
    cep_real, _ = dict_test[uttr[0]][uttr[2]]
    cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
    len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
    real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()

    # speaker embedding of the target speaker
    _, spk_emb = dict_test[uttr[1]][uttr[2]]
    spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)

    with torch.no_grad():
        spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2, 1)[:, :14, :],
                                               real_mask_A,
                                               len_real_A,
                                               spk_emb_B)

    uttr_tgt = spect_output[:len_spect[0], 0, :].cpu().numpy()

    spect_vc[f'{uttr[0]}_{uttr[1]}_{uttr[2]}'] = uttr_tgt
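
# spect_vc now maps 'source_target_utterance' keys to converted spectrograms
# of shape (time frames, feature channels), ready for the vocoder below.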

# spectrogram to waveform
# Feel free to use other vocoders
# This step requires some preparation to work; see the corresponding part of AutoVC
from AutoPST.synthesis import build_model
from AutoPST.synthesis import wavegen

model = build_model().to(device)
checkpoint = torch.load(
    hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"),
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint["state_dict"])

# Waveform synthesis is left disabled; note that librosa.output.write_wav was
# removed in librosa 0.8, so soundfile is used for writing instead.
#
# import soundfile as sf
#
# for name, sp in spect_vc.items():
#     print(name)
#     waveform = wavegen(model, c=sp)
#     sf.write('./assets/' + name + '.wav', waveform, samplerate=16000)




def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # the last streamed chunk may carry no content
        if token:
            response += token
        yield response
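
# Each yield hands gr.ChatInterface the full response so far; the UI replaces
# the displayed message on every yield, which renders as streaming text.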

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()