Spaces:
Runtime error
Runtime error
import torch | |
from diffusers import StableDiffusionPipeline | |
import numpy as np | |
import abc | |
import time_utils | |
import copy | |
import os | |
from train_funcs import TRAIN_FUNC_DICT | |
### script options
with_to_k = True                    # also edit the key projections (to_k), not only to_v
with_augs = True                    # add templated paraphrases of both prompts before editing
train_func = "train_closed_form"    # key looked up in TRAIN_FUNC_DICT by edit_model

### generation settings handed to time_utils.text2image_ldm_stable
LOW_RESOURCE = True
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

### load model (on the first GPU when one is available, otherwise on the CPU)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
ldm_stable = ldm_stable.to(device)
tokenizer = ldm_stable.tokenizer

### get layers: filled in by append_ca
ca_layers = []
def append_ca(net_):
    """Collect every ``CrossAttention`` module found under *net_*.

    Matching modules are appended to the module-level ``ca_layers`` list in
    depth-first, pre-order traversal order.  A module whose class is named
    ``CrossAttention`` is recorded and not searched any further; any other
    object exposing ``children()`` has its children searched in turn.
    """
    pending = [net_]
    while pending:
        module = pending.pop()
        if module.__class__.__name__ == 'CrossAttention':
            ca_layers.append(module)
            continue
        if hasattr(module, 'children'):
            # Push in reverse so the first child is popped (visited) first.
            pending.extend(reversed(list(module.children())))
### search the UNet's down / up / mid sub-networks for cross-attention layers
sub_nets = ldm_stable.unet.named_children()
for net in sub_nets:
    # net is a (name, module) pair
    if any(tag in net[0] for tag in ("down", "up", "mid")):
        append_ca(net[1])

### get projection matrices
# Keep only the layers whose value projection takes 768-wide inputs
# (presumably the text-encoder embedding width -- confirm against the model).
ca_clip_layers = [layer for layer in ca_layers if layer.to_v.in_features == 768]

# Live projection modules, plus pristine deep copies that edit_model uses to
# restore the original weights.  Layout: all to_v entries first, then
# (when with_to_k is set) all to_k entries in the same layer order.
projection_matrices = [layer.to_v for layer in ca_clip_layers]
og_matrices = [copy.deepcopy(proj) for proj in projection_matrices]
if with_to_k:
    projection_matrices += [layer.to_k for layer in ca_clip_layers]
    og_matrices += [copy.deepcopy(layer.to_k) for layer in ca_clip_layers]
def _augment_prompt(text):
    """Return *text* followed by its three templated paraphrases."""
    # Lower-case a leading "A" so the templates read "A photo of a ...".
    base = "a" + text[1:] if text[0:1] == "A" else text
    return [
        text,
        "A photo of " + base,
        "An image of " + base,
        "A picture of " + base,
    ]


def _align_tokens(old_text, new_text, seq_len):
    """Map each token position of *old_text* to a position in *new_text*.

    The tokens of the old prompt must occur, in order, among the tokens of
    the new prompt ("an" is treated as "a" on both sides).  Positions after
    the last matched token are mapped consecutively up to ``seq_len - 1``,
    and the list is then padded with ``seq_len - 1`` until it holds at least
    ``seq_len`` entries.

    Raises:
        ValueError: if a token of the old prompt cannot be found in the
            remainder of the new prompt.
    """
    # Element 1 of the encoding is taken as the id of "a"
    # (presumably element 0 is the start-of-text token -- confirm).
    article_a = tokenizer.encode("a ")[1]

    def unify_articles(token_ids):
        return [article_a if tokenizer.decode(t) == 'an' else t for t in token_ids]

    old_ids = unify_articles(tokenizer(old_text).input_ids)
    new_ids = unify_articles(tokenizer(new_text).input_ids)

    mapping = []
    j = 0
    for token in old_ids:
        while j < len(new_ids) and new_ids[j] != token:
            j += 1
        if j == len(new_ids):
            raise ValueError(
                f"tokens of {old_text!r} do not appear in order within {new_text!r}"
            )
        mapping.append(j)
        j += 1
    # Everything after the last matched token maps one-to-one ...
    mapping.extend(range(j, seq_len))
    # ... and whatever is still missing points at the final position.
    mapping.extend([seq_len - 1] * (seq_len - len(mapping)))
    return mapping


def edit_model(old_text_, new_text_, lamb=0.1):
    """Reset the edited projections, then edit the model from one prompt to another.

    The ``to_v`` (and, when ``with_to_k`` is set, ``to_k``) projections of
    every layer in ``ca_clip_layers`` are first restored from ``og_matrices``.
    Text embeddings of the old prompt(s) are then paired with the outputs the
    current projections produce for the aligned tokens of the new prompt(s),
    and both are handed to the training function selected by ``train_func``.

    Args:
        old_text_: source prompt; its tokens must appear, in order, within
            the tokens of ``new_text_``.
        new_text_: destination prompt.
        lamb: forwarded unchanged to the training function as ``lamb``.

    Returns:
        An HTML status string: a success message naming both prompts, or a
        generic error message if anything after the reset step raised.
    """
    import logging

    #### restart LDM parameters
    num_ca_clip_layers = len(ca_clip_layers)
    for idx_, layer in enumerate(ca_clip_layers):
        layer.to_v = copy.deepcopy(og_matrices[idx_])
        projection_matrices[idx_] = layer.to_v
        if with_to_k:
            layer.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
            projection_matrices[num_ca_clip_layers + idx_] = layer.to_k

    try:
        #### set up sentences
        if with_augs:
            old_texts = _augment_prompt(old_text_)
            new_texts = _augment_prompt(new_text_)
        else:
            old_texts, new_texts = [old_text_], [new_text_]

        # Prompts are padded/truncated to this length, so it is also the
        # number of rows in every embedding indexed below.
        seq_len = ldm_stable.tokenizer.model_max_length

        #### prepare input k* and v*
        old_embs, new_embs = [], []
        for old_text, new_text in zip(old_texts, new_texts):
            text_input = ldm_stable.tokenizer(
                [old_text, new_text],
                padding="max_length",
                max_length=seq_len,
                truncation=True,
                return_tensors="pt",
            )
            # The embeddings are only used detached or under no_grad below,
            # so there is no reason to record an autograd graph for them.
            with torch.no_grad():
                text_embeddings = ldm_stable.text_encoder(
                    text_input.input_ids.to(ldm_stable.device)
                )[0]
            old_emb, new_emb = text_embeddings
            old_embs.append(old_emb)
            new_embs.append(new_emb)

        #### identify corresponding destinations for each token in old_emb
        idxs_replaces = [
            _align_tokens(old_text, new_text, seq_len)
            for old_text, new_text in zip(old_texts, new_texts)
        ]

        #### prepare batch: for each pair of sentences, old context and new values
        contexts, valuess = [], []
        with torch.no_grad():
            for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
                contexts.append(old_emb.detach())
                # Rows of the new embedding are reordered to line up with
                # the old prompt's token positions before being projected.
                aligned = new_emb[idxs_replace]
                valuess.append([proj(aligned).detach() for proj in projection_matrices])

        #### define training function
        train = TRAIN_FUNC_DICT[train_func]

        #### train the model
        train(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, lamb=lamb)
        return f"<b>Current model status:</b> Edited \"{old_text_}\" into \"{new_text_}\""
    except Exception:
        # Best-effort for the UI: report failure through the status string,
        # but keep the traceback in the logs and let KeyboardInterrupt /
        # SystemExit propagate.
        logging.getLogger(__name__).exception(
            "edit_model failed for %r -> %r", old_text_, new_text_
        )
        return "<b>Current model status:</b> An error occured"
def generate_for_text(test_text):
    """Generate image(s) for *test_text* with the current model.

    A CPU random generator is freshly seeded on every call and passed to
    ``time_utils.text2image_ldm_stable`` together with the module-level
    sampling settings.  The result of ``time_utils.view_images`` on the
    generated images is returned.
    """
    rng = torch.Generator(device='cpu')
    rng.seed()  # draw a new non-deterministic seed for this call
    generated = time_utils.text2image_ldm_stable(
        ldm_stable,
        [test_text],
        latent=None,
        num_inference_steps=NUM_DIFFUSION_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=rng,
        low_resource=LOW_RESOURCE,
    )
    return time_utils.view_images(generated)