File size: 6,415 Bytes
0a97d6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import yaml
import torch
from diffusers import DDIMScheduler
from .model.model import DiffVC
from .model.model_cross import DiffVC_Cross
from .utils import scale_shift, scale_shift_re, rescale_noise_cfg
class ReDiffVC(object):
    """Inference wrapper around a ``DiffVC`` model driven by a DDIM scheduler.

    The constructor reads a YAML config, builds the model, restores its
    weights from a checkpoint and prepares a ``DDIMScheduler``.
    ``inference`` then runs the reverse diffusion loop conditioned on a
    speaker embedding, a content clip and an optional f0 clip.
    """

    def __init__(self,
                 config_path='configs/diffvc_base.yaml',
                 ckpt_path='../ckpts/dreamvc_base.pt',
                 device='cpu'):
        """Build the model and scheduler.

        Args:
            config_path: Path to a YAML file providing the ``model`` and
                ``scheduler`` sections read below.
            ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
                ``'model'`` entry is the model state dict.
            device: Device the model, the checkpoint tensors and the sampling
                noise are placed on.
        """
        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        model_cfg = config['model']
        scheduler_cfg = config['scheduler']

        self.device = device
        self.model = DiffVC(model_cfg).to(device)

        # map_location keeps deserialization on the requested device, so a
        # checkpoint saved from CUDA tensors still loads on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=scheduler_cfg['num_train_steps'],
            beta_start=scheduler_cfg['beta_start'],
            beta_end=scheduler_cfg['beta_end'],
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            prediction_type='v_prediction',
        )

        # Forwarded to scale_shift_re at the end of inference.
        self.scale = scheduler_cfg['scale']
        self.shift = scheduler_cfg['shift']
        # First entry of the UNet sample size; used as the third dimension
        # of the initial noise tensor.
        self.melshape = model_cfg['unet']['sample_size'][0]

    @torch.no_grad()
    def inference(self,
                  spk_embed, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):
        """Run the DDIM sampling loop and return the rescaled result.

        Args:
            spk_embed: Speaker conditioning passed straight to the model.
            content_clip: Content conditioning; ``content_clip.shape[-2]``
                sets the last dimension of the generated tensor
                (presumably the frame count -- TODO confirm).
            f0_clip: Optional pitch conditioning passed straight to the model.
            guidance_scale: Classifier-free guidance weight. A falsy value
                (``0`` or ``None``) disables guidance and runs a single
                conditional forward pass per step.
            guidance_rescale: When greater than ``0.0`` and guidance is
                enabled, the guided prediction is passed through
                ``rescale_noise_cfg`` with this factor.
            ddim_steps: Number of scheduler inference steps.
            eta: ``eta`` argument forwarded to the scheduler step.
            random_seed: Seed for the noise generator; ``None`` draws a
                fresh non-deterministic seed.

        Returns:
            The output of ``scale_shift_re`` applied to the final latents,
            which have shape ``(1, 1, self.melshape, content_clip.shape[-2])``.
        """
        self.model.eval()

        generator = torch.Generator(device=self.device)
        if random_seed is not None:
            generator.manual_seed(random_seed)
        else:
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from pure Gaussian noise (batch size is fixed to 1).
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                output_text = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)
                # Second pass with the model's cfg switches enabled
                # (presumably drops part of the conditioning -- see DiffVC).
                output_uncond = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                           train_cfg=True,
                                           speaker_cfg=1.0, pitch_cfg=0.0)

                # Extrapolate from the second pass toward the conditional one.
                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t,
                                                sample=latents,
                                                eta=eta, generator=generator).prev_sample

        pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
        return pred
class DreamVC(object):
    """Inference wrapper around a ``DiffVC_Cross`` model driven by a DDIM scheduler.

    The constructor reads a YAML config, builds the model, restores its
    weights from a checkpoint and prepares a ``DDIMScheduler``.
    ``inference`` then runs the reverse diffusion loop conditioned on a
    ``(text, text_mask)`` pair, a content clip and an optional f0 clip.
    """

    def __init__(self,
                 config_path='configs/diffvc_cross.yaml',
                 ckpt_path='../ckpts/dreamvc_cross.pt',
                 device='cpu'):
        """Build the model and scheduler.

        Args:
            config_path: Path to a YAML file providing the ``model`` and
                ``scheduler`` sections read below.
            ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
                ``'model'`` entry is the model state dict.
            device: Device the model, the checkpoint tensors and the sampling
                noise are placed on.
        """
        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        model_cfg = config['model']
        scheduler_cfg = config['scheduler']

        self.device = device
        self.model = DiffVC_Cross(model_cfg).to(device)

        # map_location keeps deserialization on the requested device, so a
        # checkpoint saved from CUDA tensors still loads on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=scheduler_cfg['num_train_steps'],
            beta_start=scheduler_cfg['beta_start'],
            beta_end=scheduler_cfg['beta_end'],
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            prediction_type='v_prediction',
        )

        # Forwarded to scale_shift_re at the end of inference.
        self.scale = scheduler_cfg['scale']
        self.shift = scheduler_cfg['shift']
        # First entry of the UNet sample size; used as the third dimension
        # of the initial noise tensor.
        self.melshape = model_cfg['unet']['sample_size'][0]

    @torch.no_grad()
    def inference(self,
                  text, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):
        """Run the DDIM sampling loop and return the rescaled result.

        Args:
            text: A two-item sequence unpacked as ``(text, text_mask)``; both
                items are passed straight to the model.
            content_clip: Content conditioning; ``content_clip.shape[-2]``
                sets the last dimension of the generated tensor
                (presumably the frame count -- TODO confirm).
            f0_clip: Optional pitch conditioning passed straight to the model.
            guidance_scale: Classifier-free guidance weight. A falsy value
                (``0`` or ``None``) disables guidance and runs a single
                conditional forward pass per step.
            guidance_rescale: When greater than ``0.0`` and guidance is
                enabled, the guided prediction is passed through
                ``rescale_noise_cfg`` with this factor.
            ddim_steps: Number of scheduler inference steps.
            eta: ``eta`` argument forwarded to the scheduler step.
            random_seed: Seed for the noise generator; ``None`` draws a
                fresh non-deterministic seed.

        Returns:
            The output of ``scale_shift_re`` applied to the final latents,
            which have shape ``(1, 1, self.melshape, content_clip.shape[-2])``.
        """
        text, text_mask = text
        self.model.eval()

        generator = torch.Generator(device=self.device)
        if random_seed is not None:
            generator.manual_seed(random_seed)
        else:
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from pure Gaussian noise (batch size is fixed to 1).
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                output_text = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                         train_cfg=False)
                # Second pass with the model's cfg switches enabled
                # (presumably drops part of the conditioning -- see DiffVC_Cross).
                output_uncond = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                           train_cfg=True,
                                           speaker_cfg=1.0, pitch_cfg=0.0)

                # Extrapolate from the second pass toward the conditional one.
                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                         train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t,
                                                sample=latents,
                                                eta=eta, generator=generator).prev_sample

        pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
        return pred
|