# https://github.com/GaParmar/img2img-turbo/blob/main/src/pix2pix_turbo.py
import os
import copy

import requests
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
from pipelines.pix2pix.model import (
    make_1step_sched,
    my_vae_encoder_fwd,
    my_vae_decoder_fwd,
)


def download_checkpoint(url, outf):
    # Stream the checkpoint from `url` to `outf` with a progress bar.
    print(f"Downloading checkpoint to {outf}")
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(outf, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print("ERROR: downloaded size does not match content-length")
    print(f"Downloaded successfully to {outf}")


class TwinConv(torch.nn.Module):
    # Blends the frozen pretrained conv_in with its finetuned copy; the
    # interpolation weight `r` is set externally before each forward pass.
    def __init__(self, convin_pretrained, convin_curr):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


class Pix2Pix_Turbo(torch.nn.Module):
    def __init__(self, name, ckpt_folder="checkpoints"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "stabilityai/sd-turbo", subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="text_encoder"
        ).cuda()
        self.sched = make_1step_sched()
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="unet"
        )

        if name == "edge_to_image":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
            if not os.path.exists(outf):
                download_checkpoint(url, outf)
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(
                r=sd["rank_unet"],
                init_lora_weights="gaussian",
                target_modules=sd["unet_lora_target_modules"],
            )

        if name == "sketch_to_image_stochastic":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
            if not os.path.exists(outf):
                download_checkpoint(url, outf)
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(
                r=sd["rank_unet"],
                init_lora_weights="gaussian",
                target_modules=sd["unet_lora_target_modules"],
            )
            # Interpolate between the pretrained and finetuned conv_in layers;
            # the weight is set per call via `self.unet.conv_in.r`.
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
        # Swap in custom encoder/decoder forwards: the encoder caches its
        # down-block activations and the decoder re-injects them through
        # learned 1x1 skip connections.
        vae.encoder.forward = my_vae_encoder_fwd.__get__(
            vae.encoder, vae.encoder.__class__
        )
        vae.decoder.forward = my_vae_decoder_fwd.__get__(
            vae.decoder, vae.decoder.__class__
        )
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(
            512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(
            256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(
            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(
            128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae_lora_config = LoraConfig(
            r=sd["rank_vae"],
            init_lora_weights="gaussian",
            target_modules=sd["vae_lora_target_modules"],
        )
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)

        # Load the finetuned LoRA weights on top of the pretrained state dicts.
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]:
            _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
        unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        for k in sd["state_dict_vae"]:
            _sd_vae[k] = sd["state_dict_vae"][k]
        vae.load_state_dict(_sd_vae)

        unet.to("cuda")
        vae.to("cuda")
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        # SD-Turbo performs a single denoising step at t=999.
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.last_prompt = ""
        self.caption_enc = None
        self.device = "cuda"

    @torch.no_grad()
    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
        # Encode the text prompt, caching the encoding across calls that
        # reuse the same prompt.
        if prompt != self.last_prompt:
            caption_tokens = self.tokenizer(
                prompt,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.cuda()
            caption_enc = self.text_encoder(caption_tokens)[0]
            self.caption_enc = caption_enc
            self.last_prompt = prompt

        if deterministic:
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )
            model_pred = self.unet(
                encoded_control,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            x_denoised = self.sched.step(
                model_pred, self.timesteps, encoded_control, return_dict=True
            ).prev_sample
            # Feed the cached encoder activations into the decoder skips.
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        else:
            # Scale the LoRA weights based on the r value; callers should pass
            # a latent-shaped `noise_map` tensor in this stochastic branch.
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )
            # Combine the encoded input and the noise, weighted by r.
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(
                unet_input,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(
                unet_output, self.timesteps, unet_input, return_dict=True
            ).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        return output_image
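
# A minimal usage sketch, guarded so importing this module stays side-effect
# free. The input path "edge_map.png" and the prompt below are illustrative
# assumptions, not part of the upstream repository; the model itself only
# expects a [-1, 1] NCHW tensor on the GPU and a caption string. The
# stochastic sketch mode would additionally need a latent-shaped `noise_map`.
if __name__ == "__main__":
    from PIL import Image
    import torchvision.transforms.functional as TF

    model = Pix2Pix_Turbo("edge_to_image")
    edge = Image.open("edge_map.png").convert("RGB")  # hypothetical input file
    c_t = TF.to_tensor(edge).unsqueeze(0).cuda() * 2.0 - 1.0  # scale to [-1, 1]
    output = model(c_t, "a photo of a modern house", deterministic=True)
    TF.to_pil_image((output[0] * 0.5 + 0.5).clamp(0, 1).cpu()).save("output.png")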