OpenSound committed on
Commit a09029c
Parent: 5099022

Update src/inference.py

Files changed (1)
  1. src/inference.py +169 -169
src/inference.py CHANGED
@@ -1,169 +1,169 @@
import os
import random
import pandas as pd
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
- from utils import scale_shift_re
+ from .utils import scale_shift_re


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


@torch.no_grad()
def inference(autoencoder, unet, gt, gt_mask,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              device='cuda',
              ):
    if neg_text is None:
        neg_text = [""]
    if tokenizer is not None:
        text_batch = tokenizer(text_raw,
                               max_length=params['text_encoder']['max_length'],
                               padding="max_length", truncation=True, return_tensors="pt")
        text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
        text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state

        uncond_text_batch = tokenizer(neg_text,
                                      max_length=params['text_encoder']['max_length'],
                                      padding="max_length", truncation=True, return_tensors="pt")
        uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
        uncond_text = text_encoder(input_ids=uncond_text,
                                   attention_mask=uncond_text_mask).last_hidden_state
    else:
        text, text_mask = None, None
        guidance_scale = None

    codec_dim = params['model']['out_chans']
    unet.eval()

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    noise_scheduler.set_timesteps(ddim_steps)

    # init noise
    noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
    latents = noise

    for t in noise_scheduler.timesteps:
        latents = noise_scheduler.scale_model_input(latents, t)

        if guidance_scale:

            latents_combined = torch.cat([latents, latents], dim=0)
            text_combined = torch.cat([text, uncond_text], dim=0)
            text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)

            if gt is not None:
                gt_combined = torch.cat([gt, gt], dim=0)
                gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
            else:
                gt_combined = None
                gt_mask_combined = None

            output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
                                      cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
            output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)

            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
                                         cls_token=None, gt=gt, mae_mask_infer=gt_mask)

        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                       sample=latents,
                                       eta=eta, generator=generator).prev_sample

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        pred[~gt_mask] = gt[~gt_mask]
    pred_wav = autoencoder(embedding=pred)
    return pred_wav


@torch.no_grad()
def eval_udit(autoencoder, unet,
              tokenizer, text_encoder,
              params, noise_scheduler,
              val_df, subset,
              audio_frames, mae=False,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2023,
              device='cuda',
              epoch=0, save_path='logs/eval/', val_num=5):
    val_df = pd.read_csv(val_df)
    val_df = val_df[val_df['split'] == subset]
    if mae:
        val_df = val_df[val_df['audio_length'] != 0]

    save_path = save_path + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)

    for i in tqdm(range(len(val_df))):
        row = val_df.iloc[i]
        text = [row['caption']]
        if mae:
            audio_path = params['data']['val_dir'] + str(row['audio_path'])
            gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
            gt = gt / (np.max(np.abs(gt)) + 1e-9)
            sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
            num_samples = 10 * sr
            if len(gt) < num_samples:
                padding = num_samples - len(gt)
                gt = np.pad(gt, (0, padding), 'constant')
            else:
                gt = gt[:num_samples]
            gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
            gt = autoencoder(audio=gt)
            B, D, L = gt.shape
            mask_len = int(L * 0.2)
            gt_mask = torch.zeros(B, D, L).to(device)
            for _ in range(2):
                start = random.randint(0, L - mask_len)
                gt_mask[:, :, start:start + mask_len] = 1
            gt_mask = gt_mask.bool()
        else:
            gt = None
            gt_mask = None

        pred = inference(autoencoder, unet, gt, gt_mask,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, neg_text=None,
                         audio_frames=audio_frames,
                         guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                         ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                         device=device)

        pred = pred.cpu().numpy().squeeze(0).squeeze(0)

        sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])

        if i + 1 >= val_num:
            break
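The only functional change in this diff is the import on line 9: the absolute `from utils import scale_shift_re` becomes the relative `from .utils import scale_shift_re`, so inference.py is now meant to be loaded as a module of its parent package (e.g. `src`) rather than run with `src/` itself on the import path. Below is a minimal, hypothetical usage sketch under that assumption; the package name `src`, the dummy tensor shapes, and the smoke test of `rescale_noise_cfg` are for illustration only and are not part of the commit.

# Hypothetical sketch (not part of the commit). Assumes the repository root is
# on sys.path and `src` is importable as a package, so the relative `.utils`
# import inside inference.py resolves. Running the file directly
# (python src/inference.py) would instead fail with
# "ImportError: attempted relative import with no known parent package".
import torch
from src.inference import rescale_noise_cfg

# Smoke-test the CFG rescaling helper on dummy tensors shaped like the
# (batch, codec_dim, audio_frames) latents used in inference().
noise_pred_text = torch.randn(1, 128, 500)
noise_cfg = 1.5 * noise_pred_text   # stand-in for an over-scaled CFG output
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(rescaled.shape)               # torch.Size([1, 128, 500])

The same package-relative layout is what downstream callers would rely on when importing `inference` or `eval_udit` from the updated module.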