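# NOTE: the three lines that preceded this module in the captured text
# ("Spaces:", "Runtime error", "Runtime error") were page-header residue from
# the site the file was copied from; they are not part of the program.
"""SAMPLING ONLY."""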
import torch
from tqdm import tqdm

import ptp_scripts.ptp_scripts as ptp
import ptp_scripts.ptp_utils as ptp_utils
# from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
from scripts.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
# Translates the diffusion model's `parameterization` attribute into the
# `model_type` string that is handed to `model_wrapper`.
MODEL_TYPES = {
    "eps": "noise",
    "v": "v",
}
class DPMSolverSampler(object):
    """Drive `DPM_Solver` over a diffusion model, with optional attention control.

    Two modes are offered by `sample`:

    * ``DPMencode=True``: a single trajectory is run on ``x_T`` with
      ``inv_emb`` used as both the conditional and unconditional condition
      and no attention controller.
    * ``DPMencode=False``: four trajectories (background, reference,
      pseudo-cross and generation) are advanced in lock-step, sharing
      `ptp.AttentionStore` controllers, and a rectangular window of the
      generated result is pasted onto the background result.
    """

    def __init__(self, model, **kwargs):
        """Store `model` and cache its `alphas_cumprod` as float32 on `model.device`.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.model = model
        alphas_cumprod = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', alphas_cumprod)

    def register_buffer(self, name, attr):
        """Set `self.<name> = attr`, moving tensors to `self.model.device` first.

        Non-tensor values are stored unchanged. `isinstance` is used so that
        `torch.Tensor` subclasses (e.g. `torch.nn.Parameter`) are moved as well.
        """
        if isinstance(attr, torch.Tensor) and attr.device != self.model.device:
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    @staticmethod
    def _blend_region(background, foreground, param):
        """Return a copy of `background` whose `param` window comes from `foreground`.

        The window is ``[:, :, param[0]:param[1], param[2]:param[3]]``.
        Neither input tensor is modified.
        """
        blended = background.clone()
        blended[:, :, param[0]:param[1], param[2]:param[3]] = \
            foreground[:, :, param[0]:param[1], param[2]:param[3]]
        return blended

    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               inv_emb=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               t_start=None,
               t_end=None,
               DPMencode=False,
               order=3,
               width=None,
               height=None,
               ref=False,
               top=None,
               left=None,
               bottom=None,
               right=None,
               segmentation_map=None,
               param=None,
               target_height=None,
               target_width=None,
               center_row_rm=None,
               center_col_rm=None,
               tau_a=0.4,
               tau_b=0.8,
               **kwargs
               ):
        """Run the multistep DPM-Solver and return ``(result, None)``.

        Args:
            steps: Number of solver steps; must be at least `order`.
            batch_size: Batch size used for the random start and for the
                conditioning-size warning.
            shape: ``(C, H, W)`` of one sample.
            conditioning: Condition for the generation trajectory. A tensor,
                or a dict whose first value is a tensor; only its leading
                dimension is inspected here.
            inv_emb: Condition used (as both conditional and unconditional)
                for the encode path and for the decode trajectories.
            x_T: Starting state. A tensor when ``DPMencode`` is true; when it
                is false the code indexes ``x_T[0]``, ``x_T[2]``, ``x_T[3]``
                and ``x_T[4]`` (the original comment says "x_T is a list").
                If omitted, ``torch.randn((batch_size, C, H, W))`` is used.
            unconditional_guidance_scale: Forwarded as ``guidance_scale``.
            unconditional_conditioning: Unconditional condition for the
                generation trajectory. The original comment notes it has to
                come in the same format as the conditioning, e.g. as encoded
                tokens.
            t_start, t_end: Optional overrides of the solver's time range.
            DPMencode: Selects the encode path (see class docstring).
            order: Solver order; the first `order` model outputs are produced
                by `low_order_sample`.
            width, height, top, left, bottom, right, segmentation_map,
            target_height, target_width, center_row_rm, center_col_rm:
                Forwarded to `ptp_utils.register_attention_control`
                (``segmentation_map[0, 0]`` is cloned for every call).
                `segmentation_map` is required when ``DPMencode`` is false.
            param: Four indices ``(d2_start, d2_end, d3_start, d3_end)``
                delimiting the window copied from the generated tensor onto
                the background tensor. Required when ``DPMencode`` is false.
            tau_a: Fraction of `steps` controlling how long the reference
                trajectory is advanced and injected.
            tau_b: Fraction of `steps` controlling how long the window blend
                is applied to ``gen['x']``.
            callback, normals_sequence, img_callback, quantize_x0, eta, mask,
            x0, temperature, noise_dropout, score_corrector,
            corrector_kwargs, verbose, log_every_t, ref, **kwargs:
                Accepted but not used by this method.

        Returns:
            ``(tensor, None)`` where the tensor is the final ``'x'`` moved to
            the device of ``self.model.betas``.

        Raises:
            ValueError: If ``DPMencode`` is false and `segmentation_map` or
                `param` is ``None``.
        """
        # Fail before any model evaluation instead of with an opaque TypeError
        # part-way through the composition path.
        if not DPMencode:
            missing = [label for label, value in (("segmentation_map", segmentation_map),
                                                  ("param", param)) if value is None]
            if missing:
                raise ValueError(
                    "sample(DPMencode=False) requires non-None arguments: " + ", ".join(missing))

        if conditioning is not None:
            if isinstance(conditioning, dict):
                cond_batch = next(iter(conditioning.values())).shape[0]
            else:
                cond_batch = conditioning.shape[0]
            if cond_batch != batch_size:
                print(f"Warning: Got {cond_batch} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.model.betas.device
        x = torch.randn(size, device=device) if x_T is None else x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        def wrap(apply_fn, condition, unconditional_condition):
            # All wrapped models share the schedule, model type and guidance settings.
            return model_wrapper(
                apply_fn,
                ns,
                model_type=MODEL_TYPES[self.model.parameterization],
                guidance_type="classifier-free",
                condition=condition,
                unconditional_condition=unconditional_condition,
                guidance_scale=unconditional_guidance_scale,
            )

        # The parameter names below are kept exactly as in the original lambdas
        # in case `model_wrapper` passes them by keyword.
        def apply_without_controller(x, t, c, DPMencode, controller, inject):
            return self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject)

        def apply_with_controller(x, t, c, DPMencode, controller, inject):
            return self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject)

        if DPMencode:
            # x_T is not a list
            dpm_solver = DPM_Solver(wrap(apply_without_controller, inv_emb, inv_emb), ns)
            data, _ = self.low_order_sample(x, dpm_solver, steps, order, t_start, t_end, device,
                                            DPMencode=DPMencode)
            for step in range(order, steps + 1):
                data = dpm_solver.sample_one_step(data, step, steps, order=order, DPMencode=DPMencode)
            return data['x'].to(device), None

        # x_T is a list
        dpm_solver_decode = DPM_Solver(wrap(apply_with_controller, inv_emb, inv_emb), ns)
        dpm_solver_gen = DPM_Solver(
            wrap(apply_with_controller, conditioning, unconditional_conditioning), ns)

        def register(attention_store, **flags):
            # Point the model's attention hooks at `attention_store`; a fresh
            # clone of the segmentation map is passed on every call.
            ptp_utils.register_attention_control(
                self.model, attention_store, center_row_rm, center_col_rm,
                target_height, target_width, width, height, top, left, bottom, right,
                segmentation_map=segmentation_map[0, 0].clone(), **flags)

        # ---- warm-up: the first `order` model outputs of every trajectory ----
        # `low_order_sample` returns a per-step list of controllers, which is
        # what the generation warm-up indexes into.
        # decoded background
        register(ptp.AttentionStore())
        orig, orig_stores = self.low_order_sample(
            x[0], dpm_solver_decode, steps, order, t_start, t_end, device,
            DPMencode=DPMencode, controller=ptp.AttentionStore())
        # NOTE: see `_warmup` comment below; the controller passed to the model
        # must be the one that was registered, so the two calls above are
        # replaced by the explicit sequence that follows.
        del orig, orig_stores

        orig_controller = ptp.AttentionStore()
        ref_controller = ptp.AttentionStore()
        cross_controller = ptp.AttentionStore()
        gen_controller = ptp.AttentionStore()

        # decoded background
        register(orig_controller)
        orig, orig_stores = self.low_order_sample(
            x[0], dpm_solver_decode, steps, order, t_start, t_end, device,
            DPMencode=DPMencode, controller=orig_controller)

        # decoded reference
        register(ref_controller)
        ref, ref_stores = self.low_order_sample(
            x[3], dpm_solver_decode, steps, order, t_start, t_end, device,
            DPMencode=DPMencode, controller=ref_controller)

        # decode for cross-attention
        register(cross_controller, pseudo_cross=True)
        cross, cross_stores = self.low_order_sample(
            x[2], dpm_solver_decode, steps, order, t_start, t_end, device,
            DPMencode=DPMencode, controller=cross_controller, ref_init=ref['x'].clone())

        # generation
        warmup_stores = [orig_stores, ref_stores, cross_stores]
        register(gen_controller, inject_bg=True)
        gen, _ = self.low_order_sample(
            x[4], dpm_solver_gen, steps, order, t_start, t_end, device,
            DPMencode=DPMencode, controller=warmup_stores, inject=True)

        # Outside the `param` window the generated model outputs are replaced
        # by the background ones.
        for i, orig_output in enumerate(orig['model_prev_list']):
            gen['model_prev_list'][i] = self._blend_region(orig_output, gen['model_prev_list'][i], param)

        # Drop the warm-up stores and start the main loop with fresh ones.
        del orig_stores, ref_stores, cross_stores, warmup_stores
        orig_controller = ptp.AttentionStore()
        ref_controller = ptp.AttentionStore()
        cross_controller = ptp.AttentionStore()
        gen_controller = ptp.AttentionStore()

        # Loop-invariant step thresholds.
        ref_inject_until = int(tau_a * steps + 1 - order)
        # NOTE(review): 0.4 is hard-coded here and happens to equal the default
        # of `tau_a`; confirm whether background injection should follow `tau_a`.
        bg_inject_until = int(0.4 * steps + 1 - order)
        blend_until = int(tau_b * steps + 1 - order)

        for step in range(order, steps + 1):
            # decoded background
            register(orig_controller)
            orig = dpm_solver_decode.sample_one_step(orig, step, steps, order=order, DPMencode=DPMencode)

            # decode for cross-attention (restarted from the current background state)
            register(cross_controller, pseudo_cross=True)
            cross['x'] = orig['x']
            cross = dpm_solver_decode.sample_one_step(cross, step, steps, order=order,
                                                      DPMencode=DPMencode, ref_init=ref['x'].clone())

            inject = step < ref_inject_until
            if inject:
                # decoded reference: only advanced while it is still injected
                register(ref_controller)
                ref = dpm_solver_decode.sample_one_step(ref, step, steps, order=order, DPMencode=DPMencode)
                controller = [orig_controller, ref_controller, cross_controller]
            else:
                controller = [orig_controller, None, cross_controller]

            # generation
            register(gen_controller, inject_bg=step < bg_inject_until)
            gen = dpm_solver_gen.sample_one_step(gen, step, steps, order=order, DPMencode=DPMencode,
                                                 controller=controller, inject=inject)

            if step < blend_until:
                gen['x'] = self._blend_region(orig['x'], gen['x'], param)

        return gen['x'].to(device), None

    def low_order_sample(self, x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=False, controller=None, inject=False, ref_init=None):
        """Produce the first `order` model outputs with lower-order multistep updates.

        Args:
            x: Starting state; ``x.shape[0]`` is used as the batch size for
                the time vectors.
            dpm_solver: Object providing `noise_schedule` (with `total_N` and
                `T`), `get_time_steps`, `model_fn` and
                `multistep_dpm_solver_update`.
            steps: Total number of solver steps; must be ``>= order``.
            order: Number of model outputs to initialise.
            t_start: Start time, or ``None`` for ``noise_schedule.T``.
            t_end: End time, or ``None`` for ``1 / noise_schedule.total_N``.
            device: Forwarded to `get_time_steps`.
            DPMencode: Forwarded to every solver call.
            controller: Either a single controller object (passed through
                unchanged), or a list of three indexable per-step sequences,
                in which case step ``i`` receives
                ``[controller[0][i], controller[1][i], controller[2][i]]``.
            inject, ref_init: Forwarded to `model_fn`.

        Returns:
            ``(state, total_controller)`` where ``state`` is a dict with keys
            ``'x'``, ``'model_prev_list'``, ``'t_prev_list'`` and
            ``'timesteps'``, and ``total_controller`` holds `controller` once
            per initialised step.

        Raises:
            AssertionError: If ``steps < order`` or the solver returns a time
                grid whose length is not ``steps + 1``.
        """
        t_0 = 1. / dpm_solver.noise_schedule.total_N if t_end is None else t_end
        t_T = dpm_solver.noise_schedule.T if t_start is None else t_start

        assert steps >= order, f"steps ({steps}) must be >= order ({order})"
        timesteps = dpm_solver.get_time_steps(skip_type="time_uniform", t_T=t_T, t_0=t_0, N=steps,
                                              device=device, DPMencode=DPMencode)
        assert timesteps.shape[0] - 1 == steps, "solver returned an unexpected number of time steps"

        per_step = isinstance(controller, list)

        def controllers_at(index):
            # A list means "one controller per step for each of three trajectories".
            if per_step:
                return [controller[0][index], controller[1][index], controller[2][index]]
            return controller

        total_controller = []
        with torch.no_grad():
            vec_t = timesteps[0].expand(x.shape[0])
            model_prev_list = [dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode,
                                                   controller=controllers_at(0),
                                                   inject=inject, ref_init=ref_init)]
            total_controller.append(controller)
            t_prev_list = [vec_t]

            # Init the first `order` values by lower order multistep DPM-Solver.
            for init_order in range(1, order):
                vec_t = timesteps[init_order].expand(x.shape[0])
                x = dpm_solver.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
                                                           solver_type='dpmsolver', DPMencode=DPMencode)
                model_prev_list.append(dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode,
                                                           controller=controllers_at(init_order),
                                                           inject=inject, ref_init=ref_init))
                total_controller.append(controller)
                t_prev_list.append(vec_t)

        return {'x': x, 'model_prev_list': model_prev_list, 't_prev_list': t_prev_list, 'timesteps': timesteps}, total_controller