import os
import copy

import requests
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig

from pipelines.pix2pix.model import (
    make_1step_sched,
    my_vae_encoder_fwd,
    my_vae_decoder_fwd,
)


class TwinConv(torch.nn.Module):
    """Interpolates between a frozen copy of the pretrained conv_in and the
    current (LoRA-adapted) conv_in, weighted by the ratio ``r``."""

    def __init__(self, convin_pretrained, convin_curr):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        # The pretrained branch is detached so gradients only flow through the
        # current convolution; the two outputs are blended with weight r.
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


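# Illustrative note (not executed): with r = 0.25, TwinConv returns
# 0.75 * conv_in_pretrained(x) + 0.25 * conv_in_curr(x), so r = 0 reproduces the
# frozen SD-Turbo conv_in exactly and r = 1 uses only the adapted convolution.

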
class Pix2Pix_Turbo(torch.nn.Module):
    """One-step image-to-image translation model built on SD-Turbo, combining
    task-specific LoRA adapters with skip connections added to the VAE."""

    def __init__(self, name, ckpt_folder="checkpoints"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "stabilityai/sd-turbo", subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="text_encoder"
        ).cuda()
        self.sched = make_1step_sched()

        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="unet"
        )

if name == "edge_to_image": |
|
url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl" |
|
os.makedirs(ckpt_folder, exist_ok=True) |
|
outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl") |
|
if not os.path.exists(outf): |
|
print(f"Downloading checkpoint to {outf}") |
|
response = requests.get(url, stream=True) |
|
total_size_in_bytes = int(response.headers.get("content-length", 0)) |
|
block_size = 1024 |
|
progress_bar = tqdm( |
|
total=total_size_in_bytes, unit="iB", unit_scale=True |
|
) |
|
with open(outf, "wb") as file: |
|
for data in response.iter_content(block_size): |
|
progress_bar.update(len(data)) |
|
file.write(data) |
|
progress_bar.close() |
|
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: |
|
print("ERROR, something went wrong") |
|
print(f"Downloaded successfully to {outf}") |
|
p_ckpt = outf |
|
sd = torch.load(p_ckpt, map_location="cpu") |
|
unet_lora_config = LoraConfig( |
|
r=sd["rank_unet"], |
|
init_lora_weights="gaussian", |
|
target_modules=sd["unet_lora_target_modules"], |
|
) |
|
|
|
if name == "sketch_to_image_stochastic": |
|
|
|
url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl" |
|
os.makedirs(ckpt_folder, exist_ok=True) |
|
outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl") |
|
if not os.path.exists(outf): |
|
print(f"Downloading checkpoint to {outf}") |
|
response = requests.get(url, stream=True) |
|
total_size_in_bytes = int(response.headers.get("content-length", 0)) |
|
block_size = 1024 |
|
progress_bar = tqdm( |
|
total=total_size_in_bytes, unit="iB", unit_scale=True |
|
) |
|
with open(outf, "wb") as file: |
|
for data in response.iter_content(block_size): |
|
progress_bar.update(len(data)) |
|
file.write(data) |
|
progress_bar.close() |
|
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: |
|
print("ERROR, something went wrong") |
|
print(f"Downloaded successfully to {outf}") |
|
p_ckpt = outf |
|
sd = torch.load(p_ckpt, map_location="cpu") |
|
unet_lora_config = LoraConfig( |
|
r=sd["rank_unet"], |
|
init_lora_weights="gaussian", |
|
target_modules=sd["unet_lora_target_modules"], |
|
) |
|
convin_pretrained = copy.deepcopy(unet.conv_in) |
|
unet.conv_in = TwinConv(convin_pretrained, unet.conv_in) |
|
|
|
        # Patch the VAE forward passes so the encoder exposes its downsampling
        # activations and the decoder can consume them as skip connections.
        vae.encoder.forward = my_vae_encoder_fwd.__get__(
            vae.encoder, vae.encoder.__class__
        )
        vae.decoder.forward = my_vae_decoder_fwd.__get__(
            vae.decoder, vae.decoder.__class__
        )

        # 1x1 convolutions that project the encoder activations onto the decoder
        # feature maps (the added skip connections).
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(
            512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(
            256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(
            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(
            128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae_lora_config = LoraConfig(
            r=sd["rank_vae"],
            init_lora_weights="gaussian",
            target_modules=sd["vae_lora_target_modules"],
        )
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)

        # Load the fine-tuned LoRA / skip-connection weights on top of the
        # pretrained state dicts.
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]:
            _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
        unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        for k in sd["state_dict_vae"]:
            _sd_vae[k] = sd["state_dict_vae"][k]
        vae.load_state_dict(_sd_vae)

        unet.to("cuda")
        vae.to("cuda")
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        # SD-Turbo performs a single denoising step at t=999.
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.last_prompt = ""
        self.caption_enc = None
        self.device = "cuda"

    @torch.no_grad()
    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
        # Re-encode the caption only when the prompt changes.
        if prompt != self.last_prompt:
            caption_tokens = self.tokenizer(
                prompt,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.cuda()
            caption_enc = self.text_encoder(caption_tokens)[0]
            self.caption_enc = caption_enc
            self.last_prompt = prompt

        if deterministic:
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )
            model_pred = self.unet(
                encoded_control,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            x_denoised = self.sched.step(
                model_pred, self.timesteps, encoded_control, return_dict=True
            ).prev_sample
            # Feed the encoder activations to the decoder as skip connections.
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        else:
            # Stochastic variant: scale the LoRA adapters and the VAE skip
            # connections by r, and blend the control latent with the noise map.
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )

            # Interpolate between the encoded control image and the noise map.
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(
                unet_input,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(
                unet_output, self.timesteps, unet_input, return_dict=True
            ).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        return output_image
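

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline code): build the
    # edge_to_image variant and run one deterministic forward pass. The dummy
    # (1, 3, 512, 512) control tensor, its value range, and the prompt are
    # placeholder assumptions; a real caller would pass a preprocessed edge map.
    # Requires a CUDA device and xformers, and downloads the LoRA checkpoint on
    # first use.
    model = Pix2Pix_Turbo("edge_to_image")
    c_t = torch.rand(1, 3, 512, 512, device="cuda")  # dummy control image
    output = model(c_t, prompt="a photo of a modern house")
    print(output.shape)  # torch.Size([1, 3, 512, 512]), values clamped to [-1, 1]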