"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import torch
import torch.nn.functional as F
import tqdm
from diffusers import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from torch import nn
from transformers import CLIPTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
from lavis.common.registry import registry
from lavis.common.utils import download_and_untar, is_url
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
from lavis.models.blip_diffusion_models.modeling_ctx_clip import CtxCLIPTextModel
from lavis.models.blip_diffusion_models.utils import numpy_to_pil, prepare_cond_image
from lavis.models.blip_diffusion_models.ptp_utils import (
LocalBlend,
P2PCrossAttnProcessor,
AttentionRefine,
)
class ProjLayer(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
super().__init__()
        # LayerNorm -> Dense1 -> Act -> Dense2 -> Dropout -> Residual
self.dense1 = nn.Linear(in_dim, hidden_dim)
self.act_fn = QuickGELU()
self.dense2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(drop_p)
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
def forward(self, x):
x_in = x
x = self.LayerNorm(x)
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
return x
@registry.register_model("blip_diffusion")
class BlipDiffusion(BaseModel):
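    """BLIP-Diffusion: subject-driven text-to-image generation and editing.

    A BLIP-2 Q-Former encodes the subject (reference image plus category text)
    into query embeddings, which ProjLayer projects into the CLIP
    text-embedding space and injects into the prompt embedding consumed by a
    Stable Diffusion UNet. An optional ControlNet adds structural control.

    A minimal usage sketch (assuming the standard LAVIS
    ``load_model_and_preprocess`` entry point and already-preprocessed inputs)::

        from lavis.models import load_model_and_preprocess

        model, vis_processors, txt_processors = load_model_and_preprocess(
            "blip_diffusion", "base", device="cuda", is_eval=True
        )
        samples = {
            "cond_images": cond_image,            # preprocessed subject image tensor
            "cond_subject": ["dog"],              # subject category of the reference
            "tgt_subject": ["dog"],               # subject category to generate
            "prompt": ["swimming in the ocean"],
        }
        images = model.generate(samples, guidance_scale=7.5, num_inference_steps=50)
    """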
PRETRAINED_MODEL_CONFIG_DICT = {
"base": "configs/models/blip-diffusion/blip_diffusion_base.yaml",
"canny": "configs/models/blip-diffusion/blip_diffusion_controlnet_canny.yaml",
"depth": "configs/models/blip-diffusion/blip_diffusion_controlnet_depth.yaml",
"hed": "configs/models/blip-diffusion/blip_diffusion_controlnet_hed.yaml",
}
def __init__(
self,
vit_model="clip_L",
qformer_num_query_token=16,
qformer_cross_attention_freq=1,
qformer_pretrained_path=None,
qformer_train=False,
sd_pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
sd_train_text_encoder=False,
controlnet_pretrained_model_name_or_path=None,
vae_half_precision=False,
proj_train=False,
):
super().__init__()
self.num_query_token = qformer_num_query_token
# BLIP-2
self.blip = Blip2Qformer(
vit_model=vit_model,
num_query_token=qformer_num_query_token,
cross_attention_freq=qformer_cross_attention_freq,
)
if qformer_pretrained_path is not None:
state_dict = torch.load(qformer_pretrained_path, map_location="cpu")[
"model"
]
# qformer keys: Qformer.bert.encoder.layer.1.attention.self.key.weight
# ckpt keys: text_model.bert.encoder.layer.1.attention.self.key.weight
for k in list(state_dict.keys()):
if "text_model" in k:
state_dict[k.replace("text_model", "Qformer")] = state_dict.pop(k)
msg = self.blip.load_state_dict(state_dict, strict=False)
            assert all("visual" in k for k in msg.missing_keys)
assert len(msg.unexpected_keys) == 0
self.qformer_train = qformer_train
# projection layer
self.proj_layer = ProjLayer(
in_dim=768, out_dim=768, hidden_dim=3072, drop_p=0.1, eps=1e-12
)
self.proj_train = proj_train
# stable diffusion
self.tokenizer = CLIPTokenizer.from_pretrained(
sd_pretrained_model_name_or_path, subfolder="tokenizer"
)
self.text_encoder = CtxCLIPTextModel.from_pretrained(
sd_pretrained_model_name_or_path, subfolder="text_encoder"
)
self.vae = AutoencoderKL.from_pretrained(
sd_pretrained_model_name_or_path, subfolder="vae"
)
if vae_half_precision:
self.vae.half()
self.unet = UNet2DConditionModel.from_pretrained(
sd_pretrained_model_name_or_path, subfolder="unet"
)
# self.unet.enable_xformers_memory_efficient_attention()
self.noise_scheduler = DDPMScheduler.from_config(
sd_pretrained_model_name_or_path, subfolder="scheduler"
)
self.sd_train_text_encoder = sd_train_text_encoder
if controlnet_pretrained_model_name_or_path is not None:
self.controlnet = ControlNetModel.from_pretrained(
controlnet_pretrained_model_name_or_path
)
self.freeze_modules()
self.ctx_embeddings_cache = nn.Parameter(
torch.zeros(1, self.num_query_token, 768), requires_grad=False
)
self._use_embeddings_cache = False
        # inference-related: context embeddings are injected into the prompt
        # starting at token position 2, i.e. right after the BOS token and the
        # leading "a" token.
        self._CTX_BEGIN_POS = 2
def freeze_modules(self):
to_freeze = [self.vae]
if not self.sd_train_text_encoder:
to_freeze.append(self.text_encoder)
if not self.qformer_train:
to_freeze.append(self.blip)
if not self.proj_train:
to_freeze.append(self.proj_layer)
for module in to_freeze:
module.eval()
module.train = self.disabled_train
module.requires_grad_(False)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
@property
def pndm_scheduler(self):
if not hasattr(self, "_pndm_scheduler"):
self._pndm_scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
return self._pndm_scheduler
@property
def ddim_scheduler(self):
if not hasattr(self, "_ddim_scheduler"):
self._ddim_scheduler = DDIMScheduler.from_config(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
return self._ddim_scheduler
def before_training(self, dataset, **kwargs):
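        """Pre-compute subject context embeddings for the (single) fine-tuning
        dataset, cache their mean as a trainable parameter, and move the BLIP
        encoder and projection layer to CPU since they are no longer needed."""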
        assert len(dataset) == 1, "Only a single dataset is supported for now."
key = list(dataset.keys())[0]
dataset = dataset[key]["train"]
# collect all examples
# [FIXME] this is not memory efficient. may OOM if the dataset is large.
examples = [dataset[i] for i in range(dataset.len_without_repeat)]
input_images = (
torch.stack([example["inp_image"] for example in examples])
.to(memory_format=torch.contiguous_format)
.float()
).to(self.device)
subject_text = [dataset.subject for _ in range(input_images.shape[0])]
# calculate ctx embeddings and cache them
ctx_embeddings = self.forward_ctx_embeddings(
input_image=input_images, text_input=subject_text
)
# take mean of all ctx embeddings
ctx_embeddings = ctx_embeddings.mean(dim=0, keepdim=True)
self.ctx_embeddings_cache = nn.Parameter(ctx_embeddings, requires_grad=True)
self._use_embeddings_cache = True
# free up CUDA memory
self.blip.to("cpu")
self.proj_layer.to("cpu")
torch.cuda.empty_cache()
def forward(self, samples):
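        """Training step: encode the target image into VAE latents, add noise at
        a random timestep, predict that noise with the UNet conditioned on the
        caption plus subject context embeddings, and return the MSE loss."""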
latents = self.vae.encode(samples["tgt_image"].half()).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
ctx_embeddings = self.forward_ctx_embeddings(
input_image=samples["inp_image"], text_input=samples["subject_text"]
)
# Get the text embedding for conditioning
input_ids = self.tokenizer(
samples["caption"],
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids.to(self.device)
encoder_hidden_states = self.text_encoder(
input_ids=input_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=[self._CTX_BEGIN_POS] * input_ids.shape[0],
)[0]
# Predict the noise residual
noise_pred = self.unet(
noisy_latents.float(), timesteps, encoder_hidden_states
).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
return {"loss": loss}
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
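        """Prefix each prompt with "a <tgt_subject>" and repeat the phrase
        (int(prompt_strength * prompt_reps) times) to amplify the conditioning."""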
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
def _build_prompts_edit(self, cond_subject, tgt_subject, prompt):
placeholder = " ".join(["sks"] * self.num_query_token)
src_prompt = f"a {cond_subject} {prompt}"
tgt_prompt = f"a {placeholder} {tgt_subject} {prompt}"
return [src_prompt, tgt_prompt]
def _predict_noise(
self,
t,
latent_model_input,
text_embeddings,
width=512,
height=512,
cond_image=None,
):
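        """Predict the noise residual with the UNet; when a ControlNet is loaded,
        first compute its residuals from the conditioning image and feed them in."""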
if hasattr(self, "controlnet"):
cond_image = prepare_cond_image(
cond_image, width, height, batch_size=1, device=self.device
)
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=cond_image,
# conditioning_scale=controlnet_condition_scale,
return_dict=False,
)
else:
down_block_res_samples, mid_block_res_sample = None, None
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)["sample"]
return noise_pred
def _init_latent(self, latent, height, width, generator, batch_size):
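        """Sample an initial noise latent at 1/8 of the target resolution if none
        is provided, then broadcast it to the requested batch size."""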
if latent is None:
latent = torch.randn(
(1, self.unet.in_channels, height // 8, width // 8),
generator=generator,
device=generator.device,
)
latent = latent.expand(
batch_size,
self.unet.in_channels,
height // 8,
width // 8,
)
return latent.to(self.device)
def _forward_prompt_embeddings(self, input_image, src_subject, prompt):
        # 1. extract BLIP query features and project them to the text space -> (bs, num_query_token, 768)
query_embeds = self.forward_ctx_embeddings(input_image, src_subject)
# 2. embeddings for prompt, with query_embeds as context
tokenized_prompt = self._tokenize_text(prompt).to(self.device)
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=[self._CTX_BEGIN_POS],
)[0]
return text_embeddings
@torch.no_grad()
def get_image_latents(self, image, sample=True, rng_generator=None):
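        """Encode an image tensor with the VAE (sampling from the posterior or
        taking its mode) and scale by 0.18215, the SD latent scaling factor."""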
assert isinstance(image, torch.Tensor)
encoding_dist = self.vae.encode(image).latent_dist
if sample:
encoding = encoding_dist.sample(generator=rng_generator)
else:
encoding = encoding_dist.mode()
latents = encoding * 0.18215
return latents
def _inversion_transform(self, image, target_size=512):
from torchvision import transforms
tform = transforms.Compose(
[
transforms.Resize(target_size),
transforms.CenterCrop(target_size),
transforms.ToTensor(),
]
)
image = tform(image).unsqueeze(0).to(self.device)
return 2.0 * image - 1.0
@torch.no_grad()
def edit(
self,
samples,
lb_threshold=0.3,
guidance_scale=7.5,
height=512,
width=512,
seed=42,
num_inference_steps=50,
num_inversion_steps=50,
neg_prompt="",
):
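        """Subject-driven editing: DDIM-invert the input image into noise, then
        regenerate it with prompt-to-prompt attention refinement so the source
        subject is replaced by the target subject while preserving the layout."""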
raw_image = samples["raw_image"]
raw_image = self._inversion_transform(raw_image)
latents = self.get_image_latents(raw_image, rng_generator=None)
inv_latents = self._ddim_inverse(
samples=samples,
latents=latents,
seed=seed,
guidance_scale=1.0,
height=height,
width=width,
num_inference_steps=num_inversion_steps,
)
recon_image = self.generate_then_edit(
samples=samples,
latents=inv_latents,
seed=seed,
neg_prompt=neg_prompt,
guidance_scale=guidance_scale,
height=height,
width=width,
num_inference_steps=num_inference_steps,
use_inversion=True,
lb_threshold=lb_threshold,
)
return recon_image
@torch.no_grad()
def _ddim_inverse(
self,
samples,
latents,
guidance_scale=1.0,
height=512,
width=512,
seed=42,
num_inference_steps=50,
):
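        """DDIM inversion: step the latents forward in time (image -> noise),
        conditioned on the source-subject prompt without query embeddings."""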
src_subject = samples["src_subject"] # source subject category
prompt = samples["prompt"]
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=src_subject,
prompt_strength=1.0,
prompt_reps=1,
)
tokenized_prompt = self._tokenize_text(prompt, with_query=False).to(self.device)
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=None,
)[0]
if seed is not None:
generator = torch.Generator(device=self.device)
generator = generator.manual_seed(seed)
latents = self._init_latent(latents, height, width, generator, batch_size=1)
scheduler = self.ddim_scheduler
# set timesteps
extra_set_kwargs = {}
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
iterator = tqdm.tqdm(reversed(scheduler.timesteps))
for i, t in enumerate(iterator):
latents = self._noise_latent_step(
latents=latents,
t=t,
text_embeddings=text_embeddings,
height=height,
width=width,
guidance_scale=guidance_scale,
)
return latents
@torch.no_grad()
def generate(
self,
samples,
latents=None,
guidance_scale=7.5,
height=512,
width=512,
seed=42,
num_inference_steps=50,
neg_prompt="",
controller=None,
prompt_strength=1.0,
prompt_reps=20,
use_ddim=False,
):
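        """Generate images of the target subject following the prompt, conditioned
        on BLIP-2 query embeddings of the reference image and, when a ControlNet
        is loaded, on an optional structural condition image."""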
        if controller is not None:
            self._register_attention_control(controller)
cond_image = samples["cond_images"] # reference image
cond_subject = samples["cond_subject"] # source subject category
tgt_subject = samples["tgt_subject"] # target subject category
prompt = samples["prompt"]
cldm_cond_image = samples.get("cldm_cond_image", None) # conditional image
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=tgt_subject,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
text_embeddings = self._forward_prompt_embeddings(
cond_image, cond_subject, prompt
)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt],
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
            # For classifier-free guidance we need both the unconditional and the
            # text embeddings; concatenate them into a single batch so the UNet
            # only needs one forward pass per denoising step.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if seed is not None:
generator = torch.Generator(device=self.device)
generator = generator.manual_seed(seed)
latents = self._init_latent(latents, height, width, generator, batch_size=1)
scheduler = self.pndm_scheduler if not use_ddim else self.ddim_scheduler
# set timesteps
extra_set_kwargs = {}
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
iterator = tqdm.tqdm(scheduler.timesteps)
for i, t in enumerate(iterator):
latents = self._denoise_latent_step(
latents=latents,
t=t,
text_embeddings=text_embeddings,
cond_image=cldm_cond_image,
height=height,
width=width,
guidance_scale=guidance_scale,
use_inversion=use_ddim,
)
image = self._latent_to_image(latents)
return image
def _register_attention_refine(
self,
src_subject,
prompts,
num_inference_steps,
cross_replace_steps=0.8,
self_replace_steps=0.4,
threshold=0.3,
):
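        """Build a prompt-to-prompt AttentionRefine controller, with LocalBlend
        restricted to the source subject, and register it on the UNet."""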
device, tokenizer = self.device, self.tokenizer
lb = LocalBlend(
prompts=prompts,
words=(src_subject,),
device=device,
tokenizer=tokenizer,
threshold=threshold,
)
controller = AttentionRefine(
prompts,
num_inference_steps,
cross_replace_steps=cross_replace_steps,
self_replace_steps=self_replace_steps,
tokenizer=tokenizer,
device=device,
local_blend=lb,
)
self._register_attention_control(controller)
return controller
def _register_attention_control(self, controller):
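        """Replace the UNet attention processors with P2PCrossAttnProcessor
        instances that route attention maps through the given controller."""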
attn_procs = {}
cross_att_count = 0
for name in self.unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else self.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = self.unet.config.block_out_channels[-1]
place_in_unet = "mid"
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.unet.config.block_out_channels))[
block_id
]
place_in_unet = "up"
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.unet.config.block_out_channels[block_id]
place_in_unet = "down"
else:
continue
cross_att_count += 1
attn_procs[name] = P2PCrossAttnProcessor(
controller=controller, place_in_unet=place_in_unet
)
self.unet.set_attn_processor(attn_procs)
if controller is not None:
controller.num_att_layers = cross_att_count
@torch.no_grad()
def generate_then_edit(
self,
samples,
cross_replace_steps=0.8,
self_replace_steps=0.4,
guidance_scale=7.5,
height=512,
width=512,
latents=None,
seed=42,
num_inference_steps=250,
neg_prompt="",
use_inversion=False,
lb_threshold=0.3,
):
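        """Denoise the source and target prompts as one batch while the
        prompt-to-prompt controller transfers attention between them, producing
        an edited image that keeps the structure of the source."""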
cond_image = samples["cond_images"] # reference image
cond_subject = samples["cond_subject"] # source subject category
src_subject = samples["src_subject"]
tgt_subject = samples["tgt_subject"] # target subject category
prompt = samples["prompt"]
        assert len(prompt) == 1, "Multiple prompts are not supported yet."
prompt = self._build_prompts_edit(src_subject, tgt_subject, prompt[0])
print(prompt)
controller = self._register_attention_refine(
src_subject=src_subject,
prompts=prompt,
num_inference_steps=num_inference_steps,
cross_replace_steps=cross_replace_steps,
self_replace_steps=self_replace_steps,
threshold=lb_threshold,
)
query_embeds = self.forward_ctx_embeddings(cond_image, cond_subject)
tokenized_prompt_bef = self._tokenize_text(prompt[:1], with_query=False).to(
self.device
)
tokenized_prompt_aft = self._tokenize_text(prompt[1:], with_query=True).to(
self.device
)
text_embeddings_bef = self.text_encoder(
input_ids=tokenized_prompt_bef.input_ids,
)[0]
text_embeddings_aft = self.text_encoder(
input_ids=tokenized_prompt_aft.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=[self._CTX_BEGIN_POS],
)[0]
text_embeddings = torch.cat([text_embeddings_bef, text_embeddings_aft], dim=0)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
# [TODO] add support for batched input
batch_size = 2
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt],
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
# FIXME use context embedding for uncond_input or not?
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# repeat the uncond embedding to match the number of prompts
uncond_embeddings = uncond_embeddings.expand(batch_size, -1, -1)
            # For classifier-free guidance we need both the unconditional and the
            # text embeddings; concatenate them into a single batch so the UNet
            # only needs one forward pass per denoising step.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if seed is not None:
generator = torch.Generator(device=self.device)
generator = generator.manual_seed(seed)
latents = self._init_latent(latents, height, width, generator, batch_size)
scheduler = self.pndm_scheduler if not use_inversion else self.ddim_scheduler
# set timesteps
scheduler.set_timesteps(num_inference_steps)
iterator = tqdm.tqdm(scheduler.timesteps)
for i, t in enumerate(iterator):
latents = self._denoise_latent_step(
latents=latents,
t=t,
text_embeddings=text_embeddings,
height=height,
width=width,
guidance_scale=guidance_scale,
use_inversion=use_inversion,
)
latents = controller.step_callback(latents)
image = self._latent_to_image(latents)
controller.reset()
return image
def _latent_to_image(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = numpy_to_pil(image)
return image
def _noise_latent_step(
self,
latents,
t,
text_embeddings,
guidance_scale,
height,
width,
):
def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
"""from noise to image"""
return (
alpha_tm1**0.5
* (
(alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
+ ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
)
+ x_t
)
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
# predict the noise residual
noise_pred = self._predict_noise(
t=t,
latent_model_input=latent_model_input,
text_embeddings=text_embeddings,
width=width,
height=height,
)
scheduler = self.ddim_scheduler
prev_timestep = (
t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
)
alpha_prod_t = scheduler.alphas_cumprod[t]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_timestep]
if prev_timestep >= 0
else scheduler.final_alpha_cumprod
)
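        # Swap the two alphas so the DDIM update below runs forward in time
        # (x_{t-1} -> x_t), i.e. it inverts the latents toward noise.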
alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
latents = backward_ddim(
x_t=latents,
alpha_t=alpha_prod_t,
alpha_tm1=alpha_prod_t_prev,
eps_xt=noise_pred,
)
return latents
def _denoise_latent_step(
self,
latents,
t,
text_embeddings,
guidance_scale,
height,
width,
cond_image=None,
use_inversion=False,
):
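        """One reverse-diffusion step: predict the noise residual, apply
        classifier-free guidance when guidance_scale > 1, and let the scheduler
        compute x_{t-1}. With use_inversion, the source branch keeps its
        un-guided conditional prediction while the edited branch is guided."""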
if use_inversion:
noise_placeholder = []
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
# predict the noise residual
noise_pred = self._predict_noise(
t=t,
latent_model_input=latent_model_input,
text_embeddings=text_embeddings,
width=width,
height=height,
cond_image=cond_image,
)
if use_inversion:
noise_placeholder.append(noise_pred[2].unsqueeze(0))
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if use_inversion:
noise_placeholder.append(noise_pred[-1].unsqueeze(0))
noise_pred = torch.cat(noise_placeholder)
# compute the previous noisy sample x_t -> x_t-1
scheduler = self.ddim_scheduler if use_inversion else self.pndm_scheduler
latents = scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
return latents
def _tokenize_text(self, text_input, with_query=True):
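        """Tokenize to the CLIP max length, reserving num_query_token positions
        when the context (query) embeddings will be injected into the prompt."""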
max_len = self.text_encoder.text_model.config.max_position_embeddings
if with_query:
max_len -= self.num_query_token
tokenized_text = self.tokenizer(
text_input,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
)
return tokenized_text
def forward_ctx_embeddings(self, input_image, text_input, ratio=None):
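        """Compute subject context embeddings: BLIP-2 multimodal features of the
        subject image/text projected into the CLIP text-embedding space. Returns
        the cached embeddings when available; multiple subjects are blended with
        the given (or uniform) ratios."""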
def compute_ctx_embeddings(input_image, text_input):
# blip_embeddings = self.blip(image=input_image, text=text_input)
blip_embeddings = self.blip.extract_features(
{"image": input_image, "text_input": text_input}, mode="multimodal"
).multimodal_embeds
ctx_embeddings = self.proj_layer(blip_embeddings)
return ctx_embeddings
if isinstance(text_input, str):
text_input = [text_input]
if self._use_embeddings_cache:
# expand to batch size
ctx_embeddings = self.ctx_embeddings_cache.expand(len(text_input), -1, -1)
else:
if isinstance(text_input[0], str):
text_input, input_image = [text_input], [input_image]
all_ctx_embeddings = []
for inp_image, inp_text in zip(input_image, text_input):
ctx_embeddings = compute_ctx_embeddings(inp_image, inp_text)
all_ctx_embeddings.append(ctx_embeddings)
if ratio is not None:
assert len(ratio) == len(all_ctx_embeddings)
assert sum(ratio) == 1
else:
ratio = [1 / len(all_ctx_embeddings)] * len(all_ctx_embeddings)
            ctx_embeddings = torch.zeros_like(all_ctx_embeddings[0])
            for r, ctx_embeddings_ in zip(ratio, all_ctx_embeddings):
                ctx_embeddings += r * ctx_embeddings_
return ctx_embeddings
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "clip_L")
qformer_cross_attention_freq = cfg.get("qformer_cross_attention_freq", 1)
qformer_num_query_token = cfg.get("qformer_num_query_token", 16)
qformer_train = cfg.get("qformer_train", False)
sd_train_text_encoder = cfg.get("sd_train_text_encoder", False)
sd_pretrained_model_name_or_path = cfg.get(
"sd_pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5"
)
controlnet_pretrained_model_name_or_path = cfg.get(
"controlnet_pretrained_model_name_or_path", None
)
vae_half_precision = cfg.get("vae_half_precision", False)
model = cls(
vit_model=vit_model,
qformer_cross_attention_freq=qformer_cross_attention_freq,
qformer_num_query_token=qformer_num_query_token,
qformer_train=qformer_train,
sd_train_text_encoder=sd_train_text_encoder,
sd_pretrained_model_name_or_path=sd_pretrained_model_name_or_path,
controlnet_pretrained_model_name_or_path=controlnet_pretrained_model_name_or_path,
vae_half_precision=vae_half_precision,
)
model.load_checkpoint_from_config(cfg)
return model
def load_checkpoint_from_dir(self, checkpoint_dir_or_url):
# if checkpoint_dir is a url, download it and untar it
if is_url(checkpoint_dir_or_url):
checkpoint_dir_or_url = download_and_untar(checkpoint_dir_or_url)
logging.info(f"Loading pretrained model from {checkpoint_dir_or_url}")
def load_state_dict(module, filename):
try:
state_dict = torch.load(
os.path.join(checkpoint_dir_or_url, filename), map_location="cpu"
)
msg = module.load_state_dict(state_dict, strict=False)
except FileNotFoundError:
logging.info("File not found, skip loading: {}".format(filename))
load_state_dict(self.proj_layer, "proj_layer/proj_weight.pt")
load_state_dict(self.blip, "blip_model/blip_weight.pt")
load_state_dict(self.unet, "unet/diffusion_pytorch_model.bin")
load_state_dict(self.vae, "vae/diffusion_pytorch_model.bin")
load_state_dict(self.text_encoder, "text_encoder/pytorch_model.bin")
try:
self.ctx_embeddings_cache.data = torch.load(
os.path.join(
checkpoint_dir_or_url, "ctx_embeddings_cache/ctx_embeddings_cache.pt"
),
map_location=self.device,
)
self._use_embeddings_cache = True
print("Loaded ctx_embeddings_cache from {}".format(checkpoint_dir_or_url))
except FileNotFoundError:
self._use_embeddings_cache = False
print("No ctx_embeddings_cache found in {}".format(checkpoint_dir_or_url))
def load_from_pretrained(self, url_or_filename):
checkpoint_dir = url_or_filename
self.load_checkpoint_from_dir(checkpoint_dir)
def load_checkpoint(self, url_or_filename):
"""
Used to load finetuned models.
"""
super().load_checkpoint(url_or_filename)
print("loading fine-tuned model from {}".format(url_or_filename))
self._use_embeddings_cache = True