Spaces:
Sleeping
Sleeping
# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/pnp.py
import spaces | |
import glob | |
import os | |
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as T | |
import argparse | |
from PIL import Image | |
import yaml | |
from tqdm import tqdm | |
from transformers import logging | |
from diffusers import DDIMScheduler, StableDiffusionPipeline | |
from pnp_utils import * | |
from unet2d_custom import UNet2DConditionModel | |
from pipeline_stable_diffusion_custom import StableDiffusionPipeline | |
from ldm.modules.encoders.audio_projector_res import Adapter | |
# suppress partial model loading warning
logging.set_verbosity_error() | |
from diffusers import logging | |
logging.set_verbosity_error() | |
class PNP(nn.Module):
    """Plug-and-Play (PnP) image editing conditioned on a text prompt and an audio clip.

    Adapted from pnp-diffusers. Holds a Stable Diffusion VAE / tokenizer /
    text encoder, a custom ``UNet2DConditionModel`` with adapter layers, a CLAP
    audio encoder and an ``Adapter`` audio projector. Editing re-runs DDIM
    sampling starting from a pre-computed noisy latent stored under
    ``self.latents_path``, with PnP conv/attention injection (from
    ``pnp_utils``) registered for a leading fraction of the timesteps.

    NOTE(review): the ``audio_projector_ckpt_path`` / ``adapter_ckpt_path``
    constructor arguments do not select the initial weights; ``__init__``
    always loads the hard-coded "landscape" checkpoints. Use
    ``set_audio_projector`` to switch checkpoints after construction.
    """

    def __init__(self, sd_version="1.4", n_timesteps=50, audio_projector_ckpt_path="ckpts/audio_projector_gh.pth",
                 adapter_ckpt_path="ckpts/greatest_hits.pt", device="cuda",
                 clap_path="CLAP/msclap",
                 clap_weights="ckpts/CLAP_weights_2022.pth",
                 ):
        """Load all models and the initial adapter / audio-projector weights.

        Args:
            sd_version: one of "1.4", "1.5", "2.0", "2.1"; selects the SD repo
                used for the VAE, tokenizer, text encoder and scheduler.
            n_timesteps: number of DDIM inference steps.
            audio_projector_ckpt_path: stored on ``self`` only (see class NOTE).
            adapter_ckpt_path: shadowed by a hard-coded path (see class NOTE).
            device: device used for the scheduler timesteps and loaded data.
            clap_path: directory appended to ``sys.path`` to import ``CLAPWrapper``.
            clap_weights: CLAP checkpoint passed to ``CLAPWrapper``.

        Raises:
            ValueError: if ``sd_version`` is not supported.
        """
        super().__init__()
        self.device = device
        if sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        elif sd_version == '1.4':
            model_key = "CompVis/stable-diffusion-v1-4"
        else:
            raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
        print(f"model key is {model_key}")

        # Create SD models
        print('Loading SD model')
        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.float16).to("cuda")

        # NOTE(review): the adapter UNet is always taken from SD v1.4,
        # independently of ``sd_version``; only the VAE / text encoder /
        # scheduler follow ``model_key``.
        model_id = "CompVis/stable-diffusion-v1-4"
        self.unet = UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            use_adapter_list=[False, True, True],
            low_cpu_mem_usage=False,
            device_map=None,
        ).to("cuda")

        # Initial checkpoints are hard-coded; the local ``adapter_ckpt_path``
        # deliberately shadows the argument of the same name (see class NOTE).
        audio_projector_path = "ckpts/audio_projector_landscape.pth"
        adapter_ckpt_path = "ckpts/landscape.pt"
        self._load_adapter_weights(adapter_ckpt_path)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        self.scheduler.set_timesteps(n_timesteps, device=self.device)

        self.latents_path = "latents_forward"
        self.output_path = "PNP-results/home"
        os.makedirs(self.output_path, exist_ok=True)

        # CLAPWrapper is imported from a local checkout, not an installed package.
        import sys
        sys.path.append(clap_path)
        from CLAPWrapper import CLAPWrapper
        self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=True)
        self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).cuda()
        self.sr = 44100

        self.text_encoder = self.text_encoder.cuda()
        self.audio_projector.load_state_dict(torch.load(audio_projector_path))
        # NOTE(review): this records the constructor argument, which is not
        # the projector checkpoint that was actually loaded above.
        self.audio_projector_ckpt_path = audio_projector_ckpt_path
        self.adapter_ckpt_path = adapter_ckpt_path
        self.changed_model = False

    def _load_adapter_weights(self, adapter_ckpt_path):
        """Overwrite every UNet parameter whose name contains "adapter" with the checkpoint tensor."""
        gate_dict = torch.load(adapter_ckpt_path)
        for name, param in self.unet.named_parameters():
            if "adapter" in name:
                param.data = gate_dict[name]

    def set_audio_projector(self, adapter_ckpt_path, audio_projector_ckpt_path):
        """Swap in new UNet adapter weights and audio-projector weights, in eval mode on CUDA."""
        print(f"SETTING MODEL TO {adapter_ckpt_path}")
        self._load_adapter_weights(adapter_ckpt_path)
        self.unet.eval()
        self.unet = self.unet.cuda()
        self.audio_projector.load_state_dict(torch.load(audio_projector_ckpt_path))
        self.audio_projector.eval()
        self.audio_projector = self.audio_projector.cuda()

    def set_text_embeds(self, prompt, negative_prompt=""):
        """Cache ``[negative, prompt]`` text embeddings and the empty-prompt PnP guidance embedding."""
        self.text_encoder = self.text_encoder.cuda()
        self.text_embeds = self.get_text_embeds(prompt, negative_prompt)
        # Unconditional half of the ("", "") pair, used for the source-latent branch.
        self.pnp_guidance_embeds = self.get_text_embeds("", "").chunk(2)[0]

    @torch.no_grad()
    def set_audio_context(self, audio_path):
        """Encode ``audio_path`` with CLAP and cache the projected audio context for the UNet.

        The cached tensor is ``cat([uncond, uncond, cond])``, matching the
        ``[source, uncond, cond]`` batch layout built in ``denoise_step``.
        """
        self.audio_projector = self.audio_projector.cuda()
        self.audio_encoder.clap.audio_encoder = self.audio_encoder.clap.audio_encoder.to("cuda")
        audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio_path], resample=self.sr)
        audio_proj = self.audio_projector(audio_emb.cuda().unsqueeze(1))
        # Unconditional audio branch: project an all-zero embedding.
        # NOTE(review): assumes the CLAP embedding width is 1024 — confirm.
        null_emb = torch.zeros(1, 1024).cuda()
        audio_uc = self.audio_projector(null_emb.unsqueeze(1))
        self.audio_context = torch.cat([audio_uc, audio_uc, audio_proj]).cuda()

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, batch_size=1):
        """Return ``cat([negative] * batch_size + [prompt] * batch_size)`` text-encoder embeddings.

        Both prompts are padded and truncated to ``tokenizer.model_max_length``.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        input_ids = text_input.input_ids.to("cuda")
        text_embeddings = self.text_encoder(input_ids)[0]
        # Do the same for unconditional embeddings; truncation keeps a long
        # negative prompt within the text encoder's maximum length.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        # Cat for final embeddings
        return torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)

    @torch.no_grad()
    def decode_latent(self, latent):
        """Decode SD latents with the VAE and map the output from [-1, 1] to [0, 1]."""
        self.vae = self.vae.cuda()
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            latent = 1 / 0.18215 * latent  # SD latent scaling factor
            img = self.vae.decode(latent).sample
            img = (img / 2 + 0.5).clamp(0, 1)
        return img

    def get_data(self, image_path):
        """Load the source image (512x512 RGB tensor) and the noisy latent for the first timestep."""
        self.image_path = image_path
        # load image; the context manager closes the underlying file handle.
        with Image.open(image_path) as src:
            image = src.convert('RGB')
        image = image.resize((512, 512), resample=Image.Resampling.LANCZOS)
        image = T.ToTensor()(image).to(self.device)
        # get noise
        latents_path = os.path.join(self.latents_path, f'noisy_latents_{self.scheduler.timesteps[0]}.pt')
        noisy_latent = torch.load(latents_path).to(self.device)
        return image, noisy_latent

    @torch.no_grad()
    def denoise_step(self, x, t, guidance_scale):
        """Run one DDIM step with classifier-free guidance and PnP injection.

        The UNet batch is ``[source_latents, x, x]``; the source prediction is
        discarded, and guidance is ``uncond + scale * (cond - uncond)``.
        """
        # register the time step and features in pnp injection modules
        source_latents = load_source_latents_t(t, self.latents_path)
        latent_model_input = torch.cat([source_latents] + ([x] * 2))
        register_time(self, t.item())
        # compute text embeddings
        text_embed_input = torch.cat([self.pnp_guidance_embeds, self.text_embeds], dim=0)
        # apply the denoising network
        noise_pred = self.unet(latent_model_input, t,
                               encoder_hidden_states=text_embed_input,
                               audio_context=self.audio_context)['sample']
        # perform guidance
        _, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # compute the denoising step with the reference model
        return self.scheduler.step(noise_pred, t, x)['prev_sample']

    def init_pnp(self, conv_injection_t, qk_injection_t):
        """Register conv / attention injection for the first N scheduler timesteps (negative N disables)."""
        self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
        self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
        register_attention_control_efficient(self, self.qk_injection_timesteps)
        register_conv_control_efficient(self, self.conv_injection_timesteps)

    def run_pnp(self, n_timesteps=50, pnp_f_t=0.5, pnp_attn_t=0.5,
                prompt="", negative_prompt="",
                audio_path="", image_path="",
                cfg_scale=5):
        """Edit ``image_path`` guided by ``prompt`` / ``negative_prompt`` and ``audio_path``.

        ``pnp_f_t`` and ``pnp_attn_t`` are fractions of ``n_timesteps`` during
        which conv-feature and attention injection are active.
        NOTE(review): the scheduler keeps the step count chosen in ``__init__``;
        ``n_timesteps`` here only scales the injection windows.

        Returns:
            The edited image as a PIL image.
        """
        self.audio_projector = self.audio_projector.cuda()
        self.set_text_embeds(prompt, negative_prompt)
        self.set_audio_context(audio_path=audio_path)
        self.image, self.eps = self.get_data(image_path=image_path)
        self.unet = self.unet.cuda()
        pnp_f_t = int(n_timesteps * pnp_f_t)
        pnp_attn_t = int(n_timesteps * pnp_attn_t)
        self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
        edited_img = self.sample_loop(self.eps, cfg_scale=cfg_scale)
        return T.ToPILImage()(edited_img[0])

    def sample_loop(self, x, cfg_scale):
        """Denoise ``x`` over all scheduler timesteps, decode, save ``output.png`` and return the decoded batch."""
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
                x = self.denoise_step(x, t, cfg_scale)
            decoded_latent = self.decode_latent(x)
            T.ToPILImage()(decoded_latent[0]).save(f'{self.output_path}/output.png')
        return decoded_latent
if __name__ == '__main__':
    # CLI entry point: run one PnP edit described by a YAML config file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='config_pnp.yaml')
    opt = parser.parse_args()
    with open(opt.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    os.makedirs(config["output_path"], exist_ok=True)
    with open(os.path.join(config["output_path"], "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    seed_everything(config["seed"])
    print(config)

    # PNP takes keyword arguments rather than the config dict; passing the dict
    # positionally (as before) made it ``sd_version`` and always raised
    # ValueError. str() guards against YAML parsing e.g. 1.5 as a float.
    # NOTE(review): key names follow the upstream pnp-diffusers config plus
    # ``audio_path``; confirm against config_pnp.yaml.
    pnp = PNP(
        sd_version=str(config.get("sd_version", "1.4")),
        n_timesteps=config.get("n_timesteps", 50),
        device=config.get("device", "cuda"),
    )
    edited = pnp.run_pnp(
        n_timesteps=config.get("n_timesteps", 50),
        pnp_f_t=config.get("pnp_f_t", 0.5),
        pnp_attn_t=config.get("pnp_attn_t", 0.5),
        prompt=config.get("prompt", ""),
        negative_prompt=config.get("negative_prompt", ""),
        audio_path=config["audio_path"],
        image_path=config["image_path"],
        cfg_scale=config.get("guidance_scale", 5),
    )
    edited.save(os.path.join(config["output_path"], "output.png"))