|
from functools import partial |
|
|
|
import torch |
|
|
|
from opensora.registry import SCHEDULERS |
|
|
|
from .dpm_solver import DPMS |
|
|
|
|
|
@SCHEDULERS.register_module("dpm-solver")
class DMP_SOLVER:
    """Sampler registered in ``SCHEDULERS`` under the name ``"dpm-solver"``.

    Thin wrapper that prepares the initial noise and text conditioning and
    then delegates the actual sampling loop to ``DPMS``.
    """

    def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
        """Store the sampling configuration.

        Args:
            num_sampling_steps: Forwarded to ``DPMS.sample`` as ``steps``.
            cfg_scale: Forwarded to ``DPMS`` as ``cfg_scale``
                (presumably the classifier-free guidance scale).
        """
        self.cfg_scale = cfg_scale
        self.num_sampling_steps = num_sampling_steps

    def sample(
        self,
        model,
        text_encoder,
        z_size,
        prompts,
        device,
        additional_args=None,
    ):
        """Draw one sample per prompt.

        Args:
            model: Object whose ``forward`` is invoked through
                ``forward_with_dpmsolver``.
            text_encoder: Must provide ``encode(prompts)`` returning a
                mutable mapping that contains the key ``"y"``, and
                ``null(n)`` for the unconditional embedding.
            z_size: Shape of a single latent, without the batch dimension.
            prompts: Sized collection of prompts; its length is the batch size.
            device: Device on which the initial noise is allocated.
            additional_args: Optional mapping merged into the model keyword
                arguments after ``"y"`` has been removed.

        Returns:
            Whatever ``DPMS.sample`` returns for the given noise.
        """
        batch_size = len(prompts)

        # Noise is drawn before the text encoder runs so that RNG consumption
        # order stays fixed.
        noise = torch.randn(batch_size, *z_size, device=device)

        model_kwargs = text_encoder.encode(prompts)
        cond = model_kwargs.pop("y")
        uncond = text_encoder.null(batch_size)

        if additional_args is not None:
            model_kwargs.update(additional_args)

        # Bind the model as the first argument so DPMS sees a plain callable.
        denoise_fn = partial(forward_with_dpmsolver, model)
        solver = DPMS(
            denoise_fn,
            condition=cond,
            uncondition=uncond,
            cfg_scale=self.cfg_scale,
            model_kwargs=model_kwargs,
        )
        return solver.sample(
            noise,
            steps=self.num_sampling_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
|
|
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
    """Call ``self.forward`` and return only the first chunk along dim 1.

    DPM-Solver does not need the variance prediction, so the output is split
    in two along dimension 1 and everything after the first piece is
    discarded (presumably the variance channels -- confirm against the model).

    Args:
        self: Model object exposing ``forward(x, timestep, y, **kwargs)``.
        x: First positional argument passed to ``forward``.
        timestep: Second positional argument passed to ``forward``.
        y: Third positional argument passed to ``forward``.
        **kwargs: Extra keyword arguments passed through unchanged.

    Returns:
        The first element of ``forward(...).chunk(2, dim=1)``.
    """
    # Index rather than unpack: ``chunk`` may return fewer than two pieces.
    pieces = self.forward(x, timestep, y, **kwargs).chunk(2, dim=1)
    return pieces[0]
|
|