|
import yaml |
|
import torch |
|
from diffusers import DDIMScheduler |
|
from .model.p2e_cross import P2E_Cross |
|
from .utils import scale_shift, scale_shift_re, rescale_noise_cfg |
|
|
|
|
|
class DreamVG(object):
    """Text-conditioned DDIM sampler built around a ``P2E_Cross`` model.

    The constructor reads a YAML config, instantiates ``P2E_Cross`` from its
    ``model`` section, restores weights from a checkpoint and configures a
    ``DDIMScheduler`` from its ``scheduler`` section. ``inference`` then runs
    the reverse diffusion loop, optionally with classifier-free guidance.

    NOTE(review): the attribute name ``spk_shape`` suggests the generated
    vector is a speaker embedding -- confirm against the model definition.
    """

    def __init__(self,
                 config_path='configs/plugin_cross.yaml',
                 ckpt_path='../ckpts/dreamvc_plugin.pt',
                 device='cpu'):
        """Load the config and model weights, and build the noise scheduler.

        Args:
            config_path: Path to a YAML file with ``model`` and ``scheduler``
                sections (the keys read below must be present).
            ckpt_path: Path to a checkpoint whose ``'model'`` entry is a
                state dict accepted by ``P2E_Cross.load_state_dict``.
            device: Device on which the model is placed and on which the
                random generator and initial noise are created.
        """
        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        self.device = device
        self.model = P2E_Cross(config['model']).to(device)

        # map_location remaps the stored tensors onto the requested device.
        # Without it, a checkpoint saved from CUDA tensors cannot be loaded on
        # a CPU-only host, even though 'cpu' is this class's default device.
        # NOTE: depending on the installed torch version's ``weights_only``
        # default, torch.load may unpickle arbitrary objects -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        scheduler_cfg = config['scheduler']
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=scheduler_cfg['num_train_steps'],
            beta_start=scheduler_cfg['beta_start'],
            beta_end=scheduler_cfg['beta_end'],
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            prediction_type='v_prediction',
        )

        # Passed to scale_shift_re() at the end of inference().
        self.scale = scheduler_cfg['scale']
        self.shift = scheduler_cfg['shift']
        # Width of the generated sample; one sample of this size is drawn
        # per inference() call.
        self.spk_shape = config['model']['unet']['in_channels']

    @torch.no_grad()
    def inference(self, text,
                  guidance_scale=5, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023,
                  ):
        """Sample one latent with DDIM, conditioned on ``text``.

        Args:
            text: A 2-tuple ``(text, text_mask)``; both items are forwarded
                unchanged to the model on every step.
            guidance_scale: Classifier-free guidance weight. Any falsy value
                (``0``, ``None``) disables guidance and only the conditional
                prediction is used.
            guidance_rescale: When greater than ``0.0`` (and guidance is
                enabled), the guided prediction is passed through
                ``rescale_noise_cfg`` with this factor.
            ddim_steps: Number of scheduler timesteps to run.
            eta: ``eta`` argument forwarded to ``DDIMScheduler.step``.
            random_seed: Seed for the random generator. ``None`` requests a
                non-deterministic seed via ``Generator.seed()``.

        Returns:
            The result of ``scale_shift_re(latents, 1 / self.scale,
            self.shift)`` applied to the final latents. The initial noise has
            shape ``(1, self.spk_shape)``; presumably the output keeps that
            shape -- confirm against ``scale_shift_re``.
        """
        text, text_mask = text
        self.model.eval()

        generator = torch.Generator(device=self.device)
        if random_seed is not None:
            generator.manual_seed(random_seed)
        else:
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from pure Gaussian noise for a single sample.
        latents = torch.randn((1, self.spk_shape),
                              generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            output_text = self.model(latents, t, text, text_mask,
                                     train_cfg=False)
            if guidance_scale:
                # NOTE(review): train_cfg=True with cfg_prob=1.0 presumably
                # drops the text condition to produce the unconditional
                # branch -- confirm against P2E_Cross.
                output_uncond = self.model(latents, t, text, text_mask,
                                           train_cfg=True, cfg_prob=1.0)

                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(
                        output_pred, output_text,
                        guidance_rescale=guidance_rescale)
            else:
                output_pred = output_text

            latents = self.noise_scheduler.step(
                model_output=output_pred, timestep=t, sample=latents,
                eta=eta, generator=generator).prev_sample

        # Map the sampled latents back through the project's scale/shift
        # helper using the configured scale (inverted) and shift.
        pred = scale_shift_re(latents, 1/self.scale, self.shift)

        return pred