nickovchinnikov committed on
Commit
bbd9e13
·
1 Parent(s): 52e3665

Fix missed file

Browse files
Files changed (1) hide show
  1. demo/delightful_univnet.py +74 -0
demo/delightful_univnet.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+
3
+ from gradio import Checkbox, Dropdown, Interface, Textbox
4
+ import soundfile as sf
5
+ import torch
6
+ from voicefixer import VoiceFixer
7
+
8
+ from models.delightful_univnet import DelightfulUnivnet
9
+ from training.datasets.hifi_libri_dataset import speakers_hifi_ids
10
+
11
+ from .config import speakers_delightful_22050
12
+
13
# The demo runs inference on the CPU.
device = torch.device("cpu")

# Checkpoint file handed to the DelightfulUnivnet constructor.
delightful_checkpoint_path = "epoch=5816-step=390418.ckpt"

# Build the model from the checkpoint, then move it onto the target device.
delightfulunivnet_22050 = DelightfulUnivnet(
    delightful_checkpoint_path=delightful_checkpoint_path,
)
delightfulunivnet_22050 = delightfulunivnet_22050.to(device)

# Shared VoiceFixer instance used for the optional restoration pass.
voicefixer = VoiceFixer()
22
+
23
+
24
def generate_audio(text: str, speaker_name: str, fix_voice: bool):
    """Synthesize speech for ``text`` and optionally restore it with VoiceFixer.

    Args:
        text: The text to synthesize.
        speaker_name: Key into ``speakers_delightful_22050`` selecting the
            speaker id passed to the model.
        fix_voice: When true, the generated waveform is written to disk,
            processed by ``voicefixer.restore`` and read back.

    Returns:
        A ``(sampling_rate, waveform)`` tuple. Without VoiceFixer the rate is
        ``delightfulunivnet_22050.sampling_rate``; with VoiceFixer it is the
        rate ``soundfile`` reports for the restored file.

    Raises:
        KeyError: If ``speaker_name`` is not in ``speakers_delightful_22050``.
    """
    import os

    speaker = torch.tensor(
        [speakers_delightful_22050[speaker_name]],
        device=device,
    )
    with torch.no_grad():
        wav = delightfulunivnet_22050.forward(text, speaker)
        wav = wav.squeeze().detach().cpu().numpy()

    if not fix_voice:
        return delightfulunivnet_22050.sampling_rate, wav

    # VoiceFixer works on file paths. Use plain paths inside a temporary
    # directory rather than open NamedTemporaryFile handles: reopening a
    # still-open named temporary file by name is not portable (it fails on
    # Windows). The directory and both files are removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, "input.wav")
        output_path = os.path.join(tmp_dir, "output.wav")

        # Write the generated waveform to the temporary wav file
        sf.write(input_path, wav, delightfulunivnet_22050.sampling_rate)

        voicefixer.restore(
            input=input_path,  # low quality .wav/.flac file
            output=output_path,  # save file path
            cuda=False,  # GPU acceleration off
            mode=0,
        )

        # Read the restored wav file back into a numpy array
        wav_vf, sampling_rate = sf.read(output_path)

    return sampling_rate, wav_vf
53
+
54
+
55
# Gradio interface wiring ``generate_audio`` to a text box, a speaker dropdown
# and a checkbox toggling the VoiceFixer pass; the result is played as audio.
# NOTE(review): the dropdown default comes from ``speakers_hifi_ids`` while the
# choices come from ``speakers_delightful_22050`` -- confirm the default is
# always one of the offered choices.
interfaceDelightfulUnuvnet22050 = Interface(
    generate_audio,
    [
        Textbox(
            label="Text",
            value="As the snake shook its head, a deafening shout behind Harry made both of them jump.",
        ),
        Dropdown(
            label="Speaker",
            choices=list(speakers_delightful_22050.keys()),
            value=speakers_hifi_ids[0],
        ),
        Checkbox(
            label="Fix voice (Voicefixer)",
            value=False,
        ),
    ],
    outputs="audio",
    # Title typo fixed: "Simpling Rate" -> "Sampling Rate".
    title=f"Delightful UnivNet, Sampling Rate: {delightfulunivnet_22050.sampling_rate}. When Voicefixer is enabled, the Sampling Rate is 44100.",
)