import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
from typing import Optional, Union, List
import warnings
warnings.filterwarnings('ignore')
from torch.optim import Adam
import torch.nn.functional as nnf
from diffusers import DDIMScheduler


##########
# helper #
##########
def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
    """Run one denoising step with classifier-free guidance.

    `context` holds the unconditional and text embeddings, concatenated along
    the batch dimension (or kept as a pair when `low_resource` is set).
    """
    if low_resource:
        # run the two conditioning branches separately to halve peak memory
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        # run both branches in a single batched UNet call
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    return latents

def image2latent(vae, image):
    """Encode a PIL image, HWC uint8 array, or preprocessed tensor into scaled VAE latents."""
    with torch.no_grad():
        if isinstance(image, PIL.Image.Image):
            image = np.array(image)
        if isinstance(image, np.ndarray):
            dtype = next(vae.parameters()).dtype
            device = next(vae.parameters()).device
            # map uint8 [0, 255] to [-1, 1] and reshape HWC -> 1CHW
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=dtype)
        # encoding outside the ndarray branch also covers already-preprocessed tensors
        latents = vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215  # SD v1 latent scaling factor
    return latents

def latent2image(vae, latents, return_type='np'):
    """Decode scaled VAE latents back to an image ('np', 'pil', or raw tensor)."""
    assert isinstance(latents, torch.Tensor)
    with torch.no_grad():
        latents = 1 / 0.18215 * latents.detach()  # undo the SD v1 latent scaling
        image = vae.decode(latents)['sample']
    if return_type in ['np', 'pil']:
        # map [-1, 1] to uint8 [0, 255] and reshape NCHW -> NHWC
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
    if return_type == 'pil':
        pilim = [PIL.Image.fromarray(imi) for imi in image]
        return pilim[0] if len(pilim) == 1 else pilim
    return image

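# Illustrative sketch (not used by the pipeline below): a VAE round-trip with the
# two helpers above. `pipe` is assumed to be a diffusers StableDiffusionPipeline-like
# object exposing `.vae`; the 512x512 size and the file path are placeholders.
def _demo_vae_roundtrip(pipe, image_path="example.png"):
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    latents = image2latent(pipe.vae, image)  # (1, 4, 64, 64) scaled latents for SD v1
    # the decoded image should closely match the input, up to VAE reconstruction error
    return latent2image(pipe.vae, latents, return_type="pil")
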
def init_latent(latent, model, height, width, generator, batch_size):
    """Draw (or reuse) an initial latent and expand it to the batch size."""
    if latent is None:
        latent = torch.randn(
            (1, model.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size, model.unet.config.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents

def txt_to_emb(model, prompt):
    """Tokenize a prompt (str or list of str) and return CLIP text embeddings."""
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_embeddings

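# Illustrative sketch: encode a prompt with the helper above. Assumes `pipe`
# exposes `.tokenizer`, `.text_encoder`, and `.device` as in a diffusers
# StableDiffusionPipeline; the prompt is a placeholder.
def _demo_txt_to_emb(pipe):
    emb = txt_to_emb(pipe, ["a photo of a cat"])
    return emb  # (1, 77, 768) for the SD v1 CLIP text encoder
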
def text2image_ldm(
    model,
    prompt: List[str],
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='pil',
):
    """Sample images from text prompts with classifier-free guidance.

    `uncond_embeddings` may be a per-step list (e.g., from null-text inversion);
    otherwise the empty-string embedding is used for every step.
    """
    batch_size = len(prompt)
    height = width = 512
    if latent is not None:
        # infer the output resolution from the latent (the VAE downsamples by 8x)
        height = latent.shape[-2] * 8
        width = latent.shape[-1] * 8
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt",
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None  # signal: use the per-step list in the loop below
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
    if return_type in ['pil', 'np']:
        image = latent2image(model.vae, latents, return_type=return_type)
    else:
        image = latents
    return image, latent

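# Illustrative sketch: plain text-to-image sampling with text2image_ldm. Assumes
# `pipe` is a StableDiffusionPipeline-like object whose scheduler is DDIM-compatible;
# the prompt and seed are placeholders.
def _demo_text2image(pipe, prompt="a photo of a cat", seed=0):
    g = torch.Generator().manual_seed(seed)
    image, latent = text2image_ldm(
        pipe, [prompt], num_inference_steps=50, guidance_scale=7.5,
        generator=g, return_type="pil")
    return image, latent
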
def text2image_ldm_imedit(
    model,
    thresh,
    prompt: List[str],
    target_prompt: List[str],
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='pil',
):
    """Prompt-switch image editing: denoise with `prompt` for the first
    (1 - thresh) fraction of steps, then with `target_prompt` for the rest.
    A larger `thresh` hands more steps to the target prompt, i.e. a stronger edit.
    """
    batch_size = len(prompt)
    height = width = 512
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    target_text_input = model.tokenizer(
        target_prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    target_text_embeddings = model.text_encoder(target_text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt",
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        # source prompt early on, target prompt once past the switch point
        emb = text_embeddings if i < (1 - thresh) * num_inference_steps else target_text_embeddings
        if uncond_embeddings_ is None:
            context = torch.cat([uncond_embeddings[i].expand(*emb.shape), emb])
        else:
            context = torch.cat([uncond_embeddings_, emb])
        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
    if return_type in ['pil', 'np']:
        image = latent2image(model.vae, latents, return_type=return_type)
    else:
        image = latents
    return image, latent

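# Illustrative sketch of the prompt-switch edit: with thresh=0.6 the first 40% of
# steps use the source prompt (fixing layout and composition) and the last 60% use
# the target prompt. Assumes `pipe` as above; prompts and seed are placeholders.
def _demo_prompt_switch_edit(pipe, seed=0):
    g = torch.Generator().manual_seed(seed)
    image, _ = text2image_ldm_imedit(
        pipe, thresh=0.6,
        prompt=["a photo of a cat"], target_prompt=["a photo of a tiger"],
        generator=g, num_inference_steps=50)
    return image
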
###########
# wrapper #
###########

class NullInversion(object):
    """Null-text inversion: DDIM-invert a real image, then optimize the per-step
    unconditional ("null") embeddings so that classifier-free guided sampling
    reproduces the inversion trajectory."""

    def __init__(self, model, num_ddim_steps, guidance_scale, device='cuda'):
        self.model = model
        self.device = device
        self.num_ddim_steps = num_ddim_steps
        self.guidance_scale = guidance_scale
        self.tokenizer = self.model.tokenizer
        self.prompt = None
        self.context = None
    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """One deterministic DDIM step backward in time (denoising, t -> t - dt)."""
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, noise_pred, timestep, sample):
        """One deterministic DDIM step forward in time (t -> t + dt), used for inversion."""
        # treat the given timestep as the *next* step and step toward it
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * noise_pred
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample
    def get_noise_pred_single(self, latents, t, context):
        """UNet noise prediction for a single conditioning."""
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Guided noise prediction followed by one DDIM step (forward or backward)."""
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        # guidance is disabled (scale 1) during forward/inversion steps
        guidance_scale = 1 if is_forward else self.guidance_scale
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents
    def init_prompt(self, prompt: str):
        """Cache the concatenated (uncond, cond) context for a prompt on self.context."""
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt
    def ddim_loop(self, latent, emb):
        """DDIM inversion: walk the clean latent forward through all timesteps,
        collecting the trajectory (index 0 is the clean latent)."""
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(self.num_ddim_steps):
            # traverse timesteps in reverse order (low noise -> high noise)
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, emb)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        return self.model.scheduler
    def ddim_invert(self, image, prompt):
        """Invert a real image to its terminal DDIM latent under `prompt`."""
        assert isinstance(image, PIL.Image.Image)
        # temporarily swap in a DDIM scheduler with num_ddim_steps timesteps
        scheduler_save = self.model.scheduler
        self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
        self.model.scheduler.set_timesteps(self.num_ddim_steps)
        with torch.no_grad():
            emb = txt_to_emb(self.model, prompt)
            latent = image2latent(self.model.vae, image)
            ddim_latents = self.ddim_loop(latent, emb)
        self.model.scheduler = scheduler_save
        return ddim_latents[-1]
    def null_optimization(self, latents, emb, nemb=None, num_inner_steps=10, epsilon=1e-5):
        """Optimize a per-step unconditional embedding so that each guided DDIM
        step reproduces the recorded inversion trajectory `latents`."""
        # run the optimization in fp32 for numerical stability, restore the dtype after
        dtype = latents[0].dtype
        uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
        cond_embeddings = emb.float()
        latents = [li.float() for li in latents]
        self.model.unet.to(torch.float32)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
        for i in range(self.num_ddim_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            # decay the learning rate over the outer (timestep) loop
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                # early stop with a tolerance that loosens at later timesteps
                if loss_item < epsilon + i * 2e-5:
                    break
            for j in range(j + 1, num_inner_steps):
                bar.update()  # advance the bar past the skipped inner steps
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
        self.model.unet.to(dtype)
        return uncond_embeddings_list
def null_invert(self, im, txt, ntxt=None, num_inner_steps=10, early_stop_epsilon=1e-5): | |
assert isinstance(im, PIL.Image.Image) | |
scheduler_save = self.model.scheduler | |
self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config) | |
self.model.scheduler.set_timesteps(self.num_ddim_steps) | |
with torch.no_grad(): | |
nemb = txt_to_emb(self.model, ntxt) \ | |
if ntxt is not None else txt_to_emb(self.model, "") | |
emb = txt_to_emb(self.model, txt) | |
latent = image2latent(self.model.vae, im) | |
# ddim inversion | |
ddim_latents = self.ddim_loop(latent, emb) | |
# nulltext inversion | |
uncond_embeddings = self.null_optimization( | |
ddim_latents, emb, nemb, num_inner_steps, early_stop_epsilon) | |
self.model.scheduler = scheduler_save | |
return ddim_latents[-1], uncond_embeddings | |
    def null_optimization_dual(
            self, latents0, latents1, emb0, emb1, nemb=None,
            num_inner_steps=10, epsilon=1e-5):
        """Jointly optimize one shared null embedding against two inversion
        trajectories, so both images reconstruct with the same null text."""
        # run the optimization in fp32 for numerical stability, restore the dtype after
        dtype = latents0[0].dtype
        uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
        cond_embeddings0, cond_embeddings1 = emb0.float(), emb1.float()
        latents0 = [li.float() for li in latents0]
        latents1 = [li.float() for li in latents1]
        self.model.unet.to(torch.float32)
        uncond_embeddings_list = []
        latent_cur0 = latents0[-1]
        latent_cur1 = latents1[-1]
        bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
        for i in range(self.num_ddim_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev0 = latents0[len(latents0) - i - 2]
            latent_prev1 = latents1[len(latents1) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond0 = self.get_noise_pred_single(latent_cur0, t, cond_embeddings0)
                noise_pred_cond1 = self.get_noise_pred_single(latent_cur1, t, cond_embeddings1)
            for j in range(num_inner_steps):
                noise_pred_uncond0 = self.get_noise_pred_single(latent_cur0, t, uncond_embeddings)
                noise_pred_uncond1 = self.get_noise_pred_single(latent_cur1, t, uncond_embeddings)
                noise_pred0 = noise_pred_uncond0 + self.guidance_scale * (noise_pred_cond0 - noise_pred_uncond0)
                noise_pred1 = noise_pred_uncond1 + self.guidance_scale * (noise_pred_cond1 - noise_pred_uncond1)
                latents_prev_rec0 = self.prev_step(noise_pred0, t, latent_cur0)
                latents_prev_rec1 = self.prev_step(noise_pred1, t, latent_cur1)
                # sum of the two per-image reconstruction losses
                loss = nnf.mse_loss(latents_prev_rec0, latent_prev0) + \
                    nnf.mse_loss(latents_prev_rec1, latent_prev1)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            for j in range(j + 1, num_inner_steps):
                bar.update()  # advance the bar past the skipped inner steps
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context0 = torch.cat([uncond_embeddings, cond_embeddings0])
                context1 = torch.cat([uncond_embeddings, cond_embeddings1])
                latent_cur0 = self.get_noise_pred(latent_cur0, t, False, context0)
                latent_cur1 = self.get_noise_pred(latent_cur1, t, False, context1)
        bar.close()
        uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
        self.model.unet.to(dtype)
        return uncond_embeddings_list
    def null_invert_dual(
            self, im0, im1, txt0, txt1, ntxt=None,
            num_inner_steps=10, early_stop_epsilon=1e-5):
        """Invert two images under their captions and jointly optimize one shared
        null embedding sequence. Returns both terminal latents and the embeddings."""
        assert isinstance(im0, PIL.Image.Image)
        assert isinstance(im1, PIL.Image.Image)
        scheduler_save = self.model.scheduler
        self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
        self.model.scheduler.set_timesteps(self.num_ddim_steps)
        with torch.no_grad():
            nemb = txt_to_emb(self.model, ntxt) \
                if ntxt is not None else txt_to_emb(self.model, "")
            latent0 = image2latent(self.model.vae, im0)
            latent1 = image2latent(self.model.vae, im1)
            emb0 = txt_to_emb(self.model, txt0)
            emb1 = txt_to_emb(self.model, txt1)
            # ddim inversion
            ddim_latents_0 = self.ddim_loop(latent0, emb0)
            ddim_latents_1 = self.ddim_loop(latent1, emb1)
        # null-text inversion
        nembs = self.null_optimization_dual(
            ddim_latents_0, ddim_latents_1, emb0, emb1, nemb, num_inner_steps, early_stop_epsilon)
        self.model.scheduler = scheduler_save
        return ddim_latents_0[-1], ddim_latents_1[-1], nembs
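
# Illustrative sketches of the inversion wrappers. Assumes `pipe` is a
# StableDiffusionPipeline-like object whose scheduler is DDIM-compatible (the
# reconstruction re-uses `pipe.scheduler` for sampling); all file paths,
# captions, and hyperparameters are placeholders.
def _demo_null_invert_reconstruct(pipe, image_path="example.png",
                                  caption="a photo of a cat"):
    # invert a real image under its caption, then re-sample with the optimized
    # per-step null embeddings; the result should closely reconstruct the input
    inv = NullInversion(pipe, num_ddim_steps=50, guidance_scale=7.5)
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    x_T, null_embs = inv.null_invert(image, caption)
    recon, _ = text2image_ldm(
        pipe, [caption], num_inference_steps=50, guidance_scale=7.5,
        latent=x_T, uncond_embeddings=null_embs)
    return recon

def _demo_null_invert_dual(pipe, path0="a.png", path1="b.png",
                           txt0="a photo of a cat", txt1="a photo of a dog"):
    # invert two images jointly so they share one null-embedding sequence,
    # keeping the two sampling trajectories aligned
    inv = NullInversion(pipe, num_ddim_steps=50, guidance_scale=7.5)
    im0 = PIL.Image.open(path0).convert("RGB").resize((512, 512))
    im1 = PIL.Image.open(path1).convert("RGB").resize((512, 512))
    xT0, xT1, null_embs = inv.null_invert_dual(im0, im1, txt0, txt1)
    return xT0, xT1, null_embs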