import os import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import Dataset, DataLoader import gradio as gr import sys import tqdm sys.path.append(os.path.abspath(os.path.join("", ".."))) import gc import warnings warnings.filterwarnings("ignore") from PIL import Image import numpy as np from editing import get_direction, debias from sampling import sample_weights from lora_w2w import LoRAw2w from transformers import CLIPTextModel from lora_w2w import LoRAw2w from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler from transformers import AutoTokenizer, PretrainedConfig from diffusers import ( AutoencoderKL, DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline ) from huggingface_hub import snapshot_download import spaces models_path = snapshot_download(repo_id="Snapchat/w2w") @spaces.GPU def load_models(device): pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51" revision = None rank = 1 weight_dtype = torch.bfloat16 # Load scheduler, tokenizer and models. pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51", torch_dtype=torch.float16,safety_checker = None, requires_safety_checker = False).to(device) noise_scheduler = pipe.scheduler del pipe tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, subfolder="tokenizer", revision=revision ) text_encoder = CLIPTextModel.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision ) vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) unet = UNet2DConditionModel.from_pretrained( pretrained_model_name_or_path, subfolder="unet", revision=revision ) unet.requires_grad_(False) unet.to(device, dtype=weight_dtype) vae.requires_grad_(False) text_encoder.requires_grad_(False) vae.requires_grad_(False) vae.to(device, dtype=weight_dtype) text_encoder.to(device, dtype=weight_dtype) print("") return unet, vae, text_encoder, tokenizer, noise_scheduler class main(): def __init__(self): super(main, self).__init__() device = "cuda" mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device) std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device) v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device) proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device) df = torch.load(f"{models_path}/files/identity_df.pt") weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt") pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device) self.device = device self.mean = mean self.std = std self.v = v self.proj = proj self.df = df self.weight_dimensions = weight_dimensions self.pinverse = pinverse pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51" revision = None rank = 1 weight_dtype = torch.bfloat16 # Load scheduler, tokenizer and models. pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51", torch_dtype=torch.float16,safety_checker = None, requires_safety_checker = False).to(device) self.noise_scheduler = pipe.scheduler del pipe self.tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, subfolder="tokenizer", revision=revision ) self.text_encoder = CLIPTextModel.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision ) self.vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) self.unet = UNet2DConditionModel.from_pretrained( pretrained_model_name_or_path, subfolder="unet", revision=revision ) self.unet.requires_grad_(False) self.unet.to(device, dtype=weight_dtype) self.vae.requires_grad_(False) self.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) self.vae.to(device, dtype=weight_dtype) self.text_encoder.to(device, dtype=weight_dtype) print("") self.network = None young = get_direction(df, "Young", pinverse, 1000, device) young = debias(young, "Male", df, pinverse, device) young = debias(young, "Pointy_Nose", df, pinverse, device) young = debias(young, "Wavy_Hair", df, pinverse, device) young = debias(young, "Chubby", df, pinverse, device) young = debias(young, "No_Beard", df, pinverse, device) young = debias(young, "Mustache", df, pinverse, device) self.young = young pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device) pointy = debias(pointy, "Young", df, pinverse, device) pointy = debias(pointy, "Male", df, pinverse, device) pointy = debias(pointy, "Wavy_Hair", df, pinverse, device) pointy = debias(pointy, "Chubby", df, pinverse, device) pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device) self.pointy = pointy wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device) wavy = debias(wavy, "Young", df, pinverse, device) wavy = debias(wavy, "Male", df, pinverse, device) wavy = debias(wavy, "Pointy_Nose", df, pinverse, device) wavy = debias(wavy, "Chubby", df, pinverse, device) wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device) self.wavy = wavy thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device) thick = debias(thick, "Male", df, pinverse, device) thick = debias(thick, "Young", df, pinverse, device) thick = debias(thick, "Pointy_Nose", df, pinverse, device) thick = debias(thick, "Wavy_Hair", df, pinverse, device) thick = debias(thick, "Mustache", df, pinverse, device) thick = debias(thick, "No_Beard", df, pinverse, device) thick = debias(thick, "Sideburns", df, pinverse, device) thick = debias(thick, "Big_Nose", df, pinverse, device) thick = debias(thick, "Big_Lips", df, pinverse, device) thick = debias(thick, "Black_Hair", df, pinverse, device) thick = debias(thick, "Brown_Hair", df, pinverse, device) thick = debias(thick, "Pale_Skin", df, pinverse, device) thick = debias(thick, "Heavy_Makeup", df, pinverse, device) self.thick = thick @torch.no_grad() @spaces.GPU(duration=1000) def sample_model(self): self.unet, _, _, _, _ = load_models(self.device) self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00) @torch.no_grad() @spaces.GPU(duration=1000) def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed): device = self.device self.unet.to(device) self.text_encoder.to(device) self.vae.to(device) self.network.to(device) generator = torch.Generator(device=device).manual_seed(seed) latents = torch.randn( (1, self.unet.in_channels, 512 // 8, 512 // 8), generator = generator, device = self.device ).bfloat16() text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] max_length = text_input.input_ids.shape[-1] uncond_input = self.tokenizer( [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt" ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) self.noise_scheduler.set_timesteps(ddim_steps) latents = latents * self.noise_scheduler.init_noise_sigma for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)): latent_model_input = torch.cat([latents] * 2) latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t) with self.network: print(latent_model_input.device) print(self.unet.device) print(self.text_encoder.device) print(self.vae.device) print(self.network.proj.device) print(self.network.mean.device) print(self.network.std.device) print(self.network.v.device) print(text_embeddings.device) noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample print("after inference") #guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = noise_scheduler.step(noise_pred, t, latents).prev_sample latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0] image = Image.fromarray((image * 255).round().astype("uint8")) return image @torch.no_grad() @spaces.GPU def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4): device = self.device original_weights = self,network.proj.clone() #pad to same number of PCs pcs_original = original_weights.shape[1] pcs_edits = self.young.shape[1] padding = torch.zeros((1,pcs_original-pcs_edits)).to(device) young_pad = torch.cat((self.young, padding), 1) pointy_pad = torch.cat((self.pointy, padding), 1) wavy_pad = torch.cat((self.wavy, padding), 1) thick_pad = torch.cat((self.thick, padding), 1) edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad generator = torch.Generator(device=device).manual_seed(seed) latents = torch.randn( (1, self.unet.in_channels, 512 // 8, 512 // 8), generator = generator, device = self.device ).bfloat16() text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = text_encoder(text_input.input_ids.to(device))[0] max_length = text_input.input_ids.shape[-1] uncond_input = tokenizer( [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt" ) uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) noise_scheduler.set_timesteps(ddim_steps) latents = latents * noise_scheduler.init_noise_sigma for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)): latent_model_input = torch.cat([latents] * 2) latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t) if t>start_noise: pass elif t<=start_noise: self.network.proj = torch.nn.Parameter(edited_weights) self.network.reset() with self.network: noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample #guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = noise_scheduler.step(noise_pred, t, latents).prev_sample latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0] image = Image.fromarray((image * 255).round().astype("uint8")) #reset weights back to original self.network.proj = torch.nn.Parameter(original_weights) self.network.reset() return image @spaces.GPU def sample_then_run(self): self.sample_model() prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity, weapon" seed = 5 cfg = 3.0 steps = 25 image = self.inference( prompt, negative_prompt, cfg, steps, seed) torch.save(self.network.proj, "model.pt" ) return image, "model.pt" class CustomImageDataset(Dataset): def __init__(self, images, transform=None): self.images = images self.transform = transform def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] if self.transform: image = self.transform(image) return image @spaces.GPU def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1): del unet del network unet, _, _, _, _ = load_models(device) proj = torch.zeros(1,pcs).bfloat16().to(device) network = LoRAw2w( proj, mean, std, v[:, :pcs], unet, rank=1, multiplier=1.0, alpha=27.0, train_method="xattn-strict" ).to(device, torch.bfloat16) ### load mask mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask) mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1) ### check if an actual mask was draw, otherwise mask is just all ones if torch.sum(mask) == 0: mask = torch.ones((1,1,64,64)).to(device).bfloat16() ### single image dataset image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), transforms.RandomCrop(512), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) train_dataset = CustomImageDataset(image, transform=image_transforms) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True) ### optimizer optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay) ### training loop unet.train() for epoch in tqdm.tqdm(range(epochs)): for batch in train_dataloader: ### prepare inputs batch = batch.to(device).bfloat16() latents = vae.encode(batch).latent_dist.sample() latents = latents*0.18215 noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = text_encoder(text_input.input_ids.to(device))[0] ### loss + sgd step with network: model_pred = unet(noisy_latents, timesteps, text_embeddings).sample loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean") optim.zero_grad() loss.backward() optim.step() ### return optimized network return network @spaces.GPU def run_inversion(self, dict, pcs, epochs, weight_decay,lr): init_image = dict["image"].convert("RGB").resize((512, 512)) mask = dict["mask"].convert("RGB").resize((512, 512)) network = invert([init_image], mask, pcs, epochs, weight_decay,lr) #sample an image prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity" seed = 5 cfg = 3.0 steps = 25 image = inference( prompt, negative_prompt, cfg, steps, seed) torch.save(network.proj, "model.pt" ) return image, "model.pt" @spaces.GPU def file_upload(self, file): proj = torch.load(file.name).to(device) #pad to 10000 Principal components to keep everything consistent pcs = proj.shape[1] padding = torch.zeros((1,10000-pcs)).to(device) proj = torch.cat((proj, padding), 1) unet, _, _, _, _ = load_models(device) network = LoRAw2w( proj, mean, std, v[:, :10000], unet, rank=1, multiplier=1.0, alpha=27.0, train_method="xattn-strict" ).to(device, torch.bfloat16) prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity" seed = 5 cfg = 3.0 steps = 25 image = inference( prompt, negative_prompt, cfg, steps, seed) return image intro = """
Project Page | Paper | Code |
""" with gr.Blocks(css="style.css") as demo: model = main() gr.HTML(intro) gr.Markdown("""