Plachta committed on
Commit 300a0b5
1 Parent(s): 2fb913a

Update app.py

Files changed (1)
  1. app.py +140 -140
app.py CHANGED
@@ -1,141 +1,141 @@
- import spaces
- import gradio as gr
- import torch
- import torchaudio
- import librosa
- from modules.commons import build_model, load_checkpoint, recursive_munch
- import yaml
- from hf_utils import load_custom_model_from_hf
-
- # Load model and configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_step_315000_seed_v2_online_pruned.pth",
-                                                                  "config_dit_mel_seed.yml")
-
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
-                                  load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model:
-     model[key].eval()
-     model[key].to(device)
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # Load additional modules
- from modules.campplus.DTDNN import CAMPPlus
-
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
- campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path']))
- campplus_model.eval()
- campplus_model.to(device)
-
- from modules.hifigan.generator import HiFTGenerator
- from modules.hifigan.f0_predictor import ConvRNNF0Predictor
-
- hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                    "hift.pt",
-                                                                    "hifigan.yml")
- hift_config = yaml.safe_load(open(hift_config_path, 'r'))
- hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
- hift_gen.load_state_dict(torch.load(hift_config['pretrained_model_path'], map_location='cpu'))
- hift_gen.eval()
- hift_gen.to(device)
-
- from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
-
- speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
-
- cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
-                                        device='cuda', device_id=0)
- # Generate mel spectrograms
- mel_fn_args = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": 8000,
-     "center": False
- }
- from modules.audio import mel_spectrogram
-
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
-
- @spaces.GPU
- @torch.no_grad()
- @torch.inference_mode()
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate):
-     # Load audio
-     source_audio = librosa.load(source, sr=sr)[0]
-     ref_audio = librosa.load(target, sr=sr)[0]
-
-     # Process audio
-     source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
-     ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
-
-     # Resample
-     source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-
-     # Extract features
-     S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
-     S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
-
-     mel = to_mel(source_audio.to(device).float())
-     mel2 = to_mel(ref_audio.to(device).float())
-
-     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
-     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
-
-     # Style encoding
-     feat = torchaudio.compliance.kaldi.fbank(source_waves_16k,
-                                              num_mel_bins=80,
-                                              dither=0,
-                                              sample_frequency=16000)
-     feat = feat - feat.mean(dim=0, keepdim=True)
-     style1 = campplus_model(feat.unsqueeze(0))
-
-     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
-                                               num_mel_bins=80,
-                                               dither=0,
-                                               sample_frequency=16000)
-     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
-     style2 = campplus_model(feat2.unsqueeze(0))
-
-     # Length regulation
-     cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
-     prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
-     cat_condition = torch.cat([prompt_condition, cond], dim=1)
-
-     # Voice Conversion
-     vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                     mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
-     vc_target = vc_target[:, :, mel2.size(-1):]
-
-     # Convert to waveform
-     vc_wave = hift_gen.inference(vc_target)
-
-     return (sr, vc_wave.squeeze(0).cpu().numpy())
-
-
- if __name__ == "__main__":
-     description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
-     inputs = [
-         gr.Audio(source="upload", type="filepath", label="Source Audio"),
-         gr.Audio(source="upload", type="filepath", label="Reference Audio"),
-         gr.Slider(minimum=1, maximum=1000, value=100, step=1, label="Diffusion Steps"),
-         gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust"),
-         gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate"),
-     ]
-
-     outputs = gr.Audio(label="Output Audio")
-
+ import spaces
+ import gradio as gr
+ import torch
+ import torchaudio
+ import librosa
+ from modules.commons import build_model, load_checkpoint, recursive_munch
+ import yaml
+ from hf_utils import load_custom_model_from_hf
+
+ # Load model and configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
+                                                                  "DiT_step_315000_seed_v2_online_pruned.pth",
+                                                                  "config_dit_mel_seed.yml")
+
+ config = yaml.safe_load(open(dit_config_path, 'r'))
+ model_params = recursive_munch(config['model_params'])
+ model = build_model(model_params, stage='DiT')
+ hop_length = config['preprocess_params']['spect_params']['hop_length']
+ sr = config['preprocess_params']['sr']
+
+ # Load checkpoints
+ model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
+                                  load_only_params=True, ignore_modules=[], is_distributed=False)
+ for key in model:
+     model[key].eval()
+     model[key].to(device)
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+
+ # Load additional modules
+ from modules.campplus.DTDNN import CAMPPlus
+
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
+ campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path']))
+ campplus_model.eval()
+ campplus_model.to(device)
+
+ from modules.hifigan.generator import HiFTGenerator
+ from modules.hifigan.f0_predictor import ConvRNNF0Predictor
+
+ hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
+                                                                    "hift.pt",
+                                                                    "hifigan.yml")
+ hift_config = yaml.safe_load(open(hift_config_path, 'r'))
+ hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
+ hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
+ hift_gen.eval()
+ hift_gen.to(device)
+
+ from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
+
+ speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
+
+ cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
+                                        device='cuda', device_id=0)
+ # Generate mel spectrograms
+ mel_fn_args = {
+     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+     "win_size": config['preprocess_params']['spect_params']['win_length'],
+     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+     "sampling_rate": sr,
+     "fmin": 0,
+     "fmax": 8000,
+     "center": False
+ }
+ from modules.audio import mel_spectrogram
+
+ to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
+
+ @spaces.GPU
+ @torch.no_grad()
+ @torch.inference_mode()
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate):
+     # Load audio
+     source_audio = librosa.load(source, sr=sr)[0]
+     ref_audio = librosa.load(target, sr=sr)[0]
+
+     # Process audio
+     source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
+     ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
+
+     # Resample
+     source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+
+     # Extract features
+     S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
+     S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
+
+     mel = to_mel(source_audio.to(device).float())
+     mel2 = to_mel(ref_audio.to(device).float())
+
+     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
+     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
+
+     # Style encoding
+     feat = torchaudio.compliance.kaldi.fbank(source_waves_16k,
+                                              num_mel_bins=80,
+                                              dither=0,
+                                              sample_frequency=16000)
+     feat = feat - feat.mean(dim=0, keepdim=True)
+     style1 = campplus_model(feat.unsqueeze(0))
+
+     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+                                               num_mel_bins=80,
+                                               dither=0,
+                                               sample_frequency=16000)
+     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
+     style2 = campplus_model(feat2.unsqueeze(0))
+
+     # Length regulation
+     cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
+     prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
+     cat_condition = torch.cat([prompt_condition, cond], dim=1)
+
+     # Voice Conversion
+     vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                     mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
+     vc_target = vc_target[:, :, mel2.size(-1):]
+
+     # Convert to waveform
+     vc_wave = hift_gen.inference(vc_target)
+
+     return (sr, vc_wave.squeeze(0).cpu().numpy())
+
+
+ if __name__ == "__main__":
+     description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
+     inputs = [
+         gr.Audio(source="upload", type="filepath", label="Source Audio"),
+         gr.Audio(source="upload", type="filepath", label="Reference Audio"),
+         gr.Slider(minimum=1, maximum=1000, value=100, step=1, label="Diffusion Steps"),
+         gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust"),
+         gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate"),
+     ]
+
+     outputs = gr.Audio(label="Output Audio")
+
      gr.Interface(fn=voice_conversion, description=description, inputs=inputs, outputs=outputs, title="Seed Voice Conversion").launch()
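
The substantive change in this hunk is old line 47 vs. new line 47: the HiFT vocoder weights are now read from hift_checkpoint_path, the local file that load_custom_model_from_hf downloads for "hift.pt", rather than from the pretrained_model_path entry inside hifigan.yml, which presumably does not resolve inside the Space. A minimal standalone sketch of the updated loading sequence, assuming the same modules/ package and hf_utils helper used by app.py are importable:

# Sketch of the HiFT vocoder setup after this commit.
import torch
import yaml
from hf_utils import load_custom_model_from_hf
from modules.hifigan.generator import HiFTGenerator
from modules.hifigan.f0_predictor import ConvRNNF0Predictor

# Fetch both the checkpoint and its config from the Plachta/Seed-VC repo.
hift_checkpoint_path, hift_config_path = load_custom_model_from_hf(
    "Plachta/Seed-VC", "hift.pt", "hifigan.yml")
hift_config = yaml.safe_load(open(hift_config_path, 'r'))
hift_gen = HiFTGenerator(**hift_config['hift'],
                         f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
# Weights now come from the downloaded hift.pt, not from a path stored in hifigan.yml.
hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
hift_gen.eval()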