# Uploaded by sky24h -- commit 910b9ab ("init commit").
"""SAMPLING ONLY.

Defines ``DPMSolverSampler``, which drives ``DPM_Solver`` (from
``scripts.dpm_solver_pytorch``) with a latent-diffusion model. It either
runs a plain multistep solve (``DPMencode=True``) or runs several latent
streams together while registering ``ptp`` attention stores
(``DPMencode=False``).
"""
import torch
import ptp_scripts.ptp_scripts as ptp
import ptp_scripts.ptp_utils as ptp_utils
# from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
from scripts.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from tqdm import tqdm
# Translates the model's ``parameterization`` attribute into the
# ``model_type`` string passed to ``model_wrapper``.
MODEL_TYPES = dict(
    eps="noise",
    v="v",
)
class DPMSolverSampler(object):
    """Sampler that drives ``DPM_Solver`` using ``self.model`` as the denoiser.

    ``sample(..., DPMencode=True)`` runs a multistep DPM-Solver on one latent,
    conditioned on ``inv_emb``.

    ``sample(..., DPMencode=False)`` runs four latent streams in lock-step:
    background, reference, pseudo cross-attention and generation. It does the
    following:

    * registers a ``ptp.AttentionStore`` for each pass;
    * hands the stored controllers to the generation pass;
    * pastes the generated window selected by ``param`` back into the
      background latent.
    """

    def __init__(self, model, **kwargs):
        """Keep ``model`` and cache its ``alphas_cumprod`` as float32 on ``model.device``.

        Args:
            model: diffusion model. This class reads its ``device``,
                ``alphas_cumprod``, ``betas``, ``parameterization`` and
                ``apply_model`` attributes.
            **kwargs: accepted but ignored.
        """
        super().__init__()
        self.model = model
        alphas = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', alphas)

    def register_buffer(self, name, attr):
        """Set ``self.<name> = attr``, first moving tensors onto ``self.model.device``."""
        # isinstance (not an exact type check) so Tensor subclasses such as
        # nn.Parameter are moved as well.
        if isinstance(attr, torch.Tensor) and attr.device != self.model.device:
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    def _make_model_fn(self, ns, condition, unconditional_condition, guidance_scale,
                       use_controller=True):
        """Wrap ``self.model.apply_model`` for DPM-Solver with classifier-free guidance.

        The wrapped function forwards its ``DPMencode`` argument to
        ``apply_model`` as ``encode``. When ``use_controller`` is False, the
        controller supplied by the solver is discarded and ``controller=None``
        is forwarded instead.
        """
        # Parameter names are kept identical to the original lambdas in case
        # model_wrapper passes any of them by keyword.
        def apply_model(x, t, c, DPMencode, controller, inject):
            return self.model.apply_model(
                x, t, c,
                encode=DPMencode,
                controller=controller if use_controller else None,
                inject=inject,
            )

        return model_wrapper(
            apply_model,
            ns,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=condition,
            unconditional_condition=unconditional_condition,
            guidance_scale=guidance_scale,
        )

    @staticmethod
    def _paste_region(base, src, param):
        """Return a copy of ``base`` with one window overwritten from ``src``.

        The window is the dim-2 slice ``param[0]:param[1]`` and the dim-3
        slice ``param[2]:param[3]``. Neither input tensor is modified.
        """
        window = (slice(None), slice(None),
                  slice(param[0], param[1]), slice(param[2], param[3]))
        blended = base.clone()
        # Slice assignment copies the values, so no extra clone is needed.
        blended[window] = src[window]
        return blended

    @torch.no_grad()
    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               inv_emb=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               unconditional_conditioning=None,
               t_start=None,
               t_end=None,
               DPMencode=False,
               order=3,
               width=None,
               height=None,
               ref=False,
               top=None,
               left=None,
               bottom=None,
               right=None,
               segmentation_map=None,
               param=None,
               target_height=None,
               target_width=None,
               center_row_rm=None,
               center_col_rm=None,
               tau_a=0.4,
               tau_b=0.8,
               tau_bg=0.4,
               **kwargs
               ):
        """Run DPM-Solver for ``steps`` steps of order ``order``.

        Args:
            steps: number of solver steps; must be >= ``order``.
            batch_size: batch size, used for the fallback noise and for a
                shape warning against ``conditioning``.
            shape: ``(C, H, W)`` of a single latent.
            conditioning, unconditional_conditioning: guidance embeddings for
                the generation stream (``DPMencode=False`` only).
            inv_emb: embedding used as both conditional and unconditional
                input for the encode pass and for the background, reference
                and cross decoding streams.
            x_T: starting latent(s). With ``DPMencode=True`` this is a single
                latent. Otherwise it is indexed per stream: ``[0]``
                background, ``[3]`` reference, ``[2]`` cross-attention, ``[4]``
                generation. Random noise is used when it is None.
            unconditional_guidance_scale: classifier-free guidance scale.
            t_start, t_end: optional overrides for the solver's time range.
            DPMencode: selects the encode path (True) or the composition path
                (False).
            order: DPM-Solver multistep order.
            width, height, top, left, bottom, right, target_height,
            target_width, center_row_rm, center_col_rm: forwarded unchanged to
                ``ptp_utils.register_attention_control``.
            segmentation_map: ``segmentation_map[0, 0]`` is cloned and
                forwarded to ``register_attention_control``. Required when
                ``DPMencode`` is False.
            param: 4-sequence selecting the pasted window (see
                ``_paste_region``). Required when ``DPMencode`` is False.
            tau_a: fraction of steps during which the reference stream is
                advanced and injected.
            tau_b: fraction of steps during which the generated window is
                pasted into the background latent.
            tau_bg: fraction of steps during which ``inject_bg=True`` is
                registered for the generation pass. Default 0.4 keeps the
                previous hard-coded behaviour.
            callback, normals_sequence, img_callback, quantize_x0, eta, mask,
            x0, temperature, noise_dropout, score_corrector, corrector_kwargs,
            verbose, log_every_t, ref, **kwargs: accepted but unused here.

        Returns:
            ``(latent, None)``, where the latent is moved to the model's
            ``betas`` device.

        Raises:
            ValueError: if ``DPMencode`` is False and ``segmentation_map`` or
                ``param`` is None.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = next(iter(conditioning.values())).shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        C, H, W = shape
        size = (batch_size, C, H, W)
        device = self.model.betas.device
        x = torch.randn(size, device=device) if x_T is None else x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        def advance(solver, state, step, **extra):
            # One high-order multistep update with the shared solver settings.
            return solver.sample_one_step(state, step, steps, order=order,
                                          DPMencode=DPMencode, **extra)

        if DPMencode:
            # Single latent, no attention controller.
            model_fn = self._make_model_fn(ns, inv_emb, inv_emb, unconditional_guidance_scale,
                                           use_controller=False)
            dpm_solver = DPM_Solver(model_fn, ns)
            data, _ = self.low_order_sample(x, dpm_solver, steps, order, t_start, t_end, device,
                                            DPMencode=DPMencode)
            for step in range(order, steps + 1):
                data = advance(dpm_solver, data, step)
            return data['x'].to(device), None

        if segmentation_map is None or param is None:
            raise ValueError("segmentation_map and param are required when DPMencode is False")

        dpm_solver_decode = DPM_Solver(
            self._make_model_fn(ns, inv_emb, inv_emb, unconditional_guidance_scale), ns)
        dpm_solver_gen = DPM_Solver(
            self._make_model_fn(ns, conditioning, unconditional_conditioning,
                                unconditional_guidance_scale), ns)

        def register(controller, **flags):
            # A fresh clone per call, since the callee may retain or modify the map.
            ptp_utils.register_attention_control(
                self.model, controller, center_row_rm, center_col_rm, target_height, target_width,
                width, height, top, left, bottom, right,
                segmentation_map=segmentation_map[0, 0].clone(), **flags)

        def warm_up(latent, solver, controller, **extra):
            # The first ``order`` model evaluations via lower-order updates.
            return self.low_order_sample(latent, solver, steps, order, t_start, t_end, device,
                                         DPMencode=DPMencode, controller=controller, **extra)

        # ---- warm-up phase ----
        # Each warm-up returns the per-evaluation list of the controllers it was given.
        orig_controller = ptp.AttentionStore()
        ref_controller = ptp.AttentionStore()
        cross_controller = ptp.AttentionStore()
        gen_controller = ptp.AttentionStore()

        # decoded background
        register(orig_controller)
        orig, orig_history = warm_up(x[0], dpm_solver_decode, orig_controller)
        # decoded reference
        register(ref_controller)
        ref_state, ref_history = warm_up(x[3], dpm_solver_decode, ref_controller)
        # decode for cross-attention
        register(cross_controller, pseudo_cross=True)
        cross, cross_history = warm_up(x[2], dpm_solver_decode, cross_controller,
                                       ref_init=ref_state['x'].clone())
        # generation, fed with the three stored histories
        inject_history = [orig_history, ref_history, cross_history]
        register(gen_controller, inject_bg=True)
        gen, _ = warm_up(x[4], dpm_solver_gen, inject_history, inject=True)

        for i, orig_prev in enumerate(orig['model_prev_list']):
            gen['model_prev_list'][i] = self._paste_region(orig_prev, gen['model_prev_list'][i], param)

        # Drop the warm-up stores before stepping; the main loop uses new ones.
        del orig_history, ref_history, cross_history, inject_history
        orig_controller = ptp.AttentionStore()
        ref_controller = ptp.AttentionStore()
        cross_controller = ptp.AttentionStore()
        gen_controller = ptp.AttentionStore()

        # ---- main loop ----
        # Step thresholds are loop-invariant; same arithmetic as before.
        inject_until = int(tau_a * steps + 1 - order)
        inject_bg_until = int(tau_bg * steps + 1 - order)
        paste_until = int(tau_b * steps + 1 - order)

        for step in range(order, steps + 1):
            # decoded background
            register(orig_controller)
            orig = advance(dpm_solver_decode, orig, step)

            # decode for cross-attention, restarted from the updated background latent
            register(cross_controller, pseudo_cross=True)
            cross['x'] = orig['x']
            cross = advance(dpm_solver_decode, cross, step, ref_init=ref_state['x'].clone())

            inject = step < inject_until
            if inject:
                # decoded reference (only advanced while it is being injected)
                register(ref_controller)
                ref_state = advance(dpm_solver_decode, ref_state, step)
                controller = [orig_controller, ref_controller, cross_controller]
            else:
                controller = [orig_controller, None, cross_controller]

            # generation
            register(gen_controller, inject_bg=step < inject_bg_until)
            gen = advance(dpm_solver_gen, gen, step, controller=controller, inject=inject)

            if step < paste_until:
                gen['x'] = self._paste_region(orig['x'], gen['x'], param)

        return gen['x'].to(device), None

    def low_order_sample(self, x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=False,
                         controller=None, inject=False, ref_init=None):
        """Start a multistep DPM-Solver run with ``order`` model evaluations.

        The evaluation at ``timesteps[0]`` is followed by ``order - 1``
        lower-order multistep updates. Each update is followed by a model
        evaluation at the new time.

        Args:
            x: starting latent.
            dpm_solver: solver providing ``noise_schedule``, ``get_time_steps``,
                ``model_fn`` and ``multistep_dpm_solver_update``.
            steps: total number of solver steps; must be >= ``order``.
            order: multistep order; also the number of evaluations done here.
            t_start, t_end: optional time-range overrides. They default to
                ``noise_schedule.T`` and ``1 / noise_schedule.total_N``.
            device: passed to ``get_time_steps``.
            DPMencode, inject, ref_init: forwarded to every ``model_fn`` call.
            controller: forwarded to ``model_fn``. If it is a list of three
                per-evaluation sequences, evaluation ``k`` receives
                ``[controller[0][k], controller[1][k], controller[2][k]]``.

        Returns:
            A ``(state, controllers)`` pair.

            * ``state`` is a dict with keys ``'x'``, ``'model_prev_list'``,
              ``'t_prev_list'`` and ``'timesteps'``.
            * ``controllers`` lists the ``controller`` argument once per model
              evaluation.
        """
        t_0 = 1. / dpm_solver.noise_schedule.total_N if t_end is None else t_end
        t_T = dpm_solver.noise_schedule.T if t_start is None else t_start
        assert steps >= order
        timesteps = dpm_solver.get_time_steps(skip_type="time_uniform", t_T=t_T, t_0=t_0, N=steps,
                                              device=device, DPMencode=DPMencode)
        assert timesteps.shape[0] - 1 == steps

        def controller_at(k):
            # Per-evaluation selection for the [orig, ref, cross] history form.
            if isinstance(controller, list):
                return [controller[0][k], controller[1][k], controller[2][k]]
            return controller

        def evaluate(latent, vec_t, k):
            return dpm_solver.model_fn(latent, vec_t, DPMencode=DPMencode,
                                       controller=controller_at(k),
                                       inject=inject, ref_init=ref_init)

        total_controller = []
        with torch.no_grad():
            vec_t = timesteps[0].expand((x.shape[0]))
            model_prev_list = [evaluate(x, vec_t, 0)]
            total_controller.append(controller)
            t_prev_list = [vec_t]
            # Init the first `order` values by lower order multistep DPM-Solver.
            for init_order in range(1, order):
                vec_t = timesteps[init_order].expand(x.shape[0])
                x = dpm_solver.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t,
                                                           init_order, solver_type='dpmsolver',
                                                           DPMencode=DPMencode)
                model_prev_list.append(evaluate(x, vec_t, init_order))
                total_controller.append(controller)
                t_prev_list.append(vec_t)
        return ({'x': x, 'model_prev_list': model_prev_list, 't_prev_list': t_prev_list,
                 'timesteps': timesteps}, total_controller)