"""
A wrapper around a flux model that generates segmentation masks for particular
concepts.
"""
from abc import ABC, abstractmethod
import PIL
import torch
import numpy as np
import einops
from torchvision import transforms
import torchvision.transforms.functional as F
from concept_attention.flux.src.flux.sampling import get_noise, get_schedule, prepare, unpack
from concept_attention.image_generator import FluxGenerator
from concept_attention.utils import embed_concepts, linear_normalization
class SegmentationAbstractClass(ABC):
@abstractmethod
def segment_individual_image(
self,
image: PIL.Image.Image,
concepts: list[str],
caption: str,
**kwargs
):
"""
Segments an individual image
"""
pass
def __call__(
self,
images: PIL.Image.Image | list[PIL.Image.Image],
target_concepts: list[str],
concepts: list[str],
captions: list[str],
mean_value_threshold: bool = True,
joint_attention_kwargs=None,
apply_blur=False,
**kwargs
):
if not isinstance(images, list):
images = [images]
# Encode each image using the flux model
all_coefficients, reconstructed_images, all_masks = [], [], []
for index, image in enumerate(images):
coefficients, reconstructed_image = self.segment_individual_image(
image,
concepts,
captions[index],
joint_attention_kwargs=joint_attention_kwargs,
**kwargs
)
# Apply a blur to the coefficients
if apply_blur:
coefficients = F.gaussian_blur(coefficients.unsqueeze(0), kernel_size=3, sigma=1.0).squeeze()
# Threshold each coefficient to make a set of masks
mean_values = torch.mean(coefficients, dim=(1, 2), keepdim=True)
masks = coefficients > mean_values
# Check whether a particular target concept was specified
if target_concepts is None:
# Return all masks
all_masks.append(masks)
all_coefficients.append(coefficients)
reconstructed_images.append(reconstructed_image)
else:
# Binarize the coefficients to generate a segmentation mask
target_concept_index = concepts.index(target_concepts[index])
if mean_value_threshold:
mean_value = coefficients[target_concept_index].mean()
mask = coefficients[target_concept_index] > mean_value
else:
mask = coefficients[target_concept_index] > 0.0
target_concept_coefficients = coefficients[target_concept_index]
mask = mask.cpu().numpy()
target_concept_coefficients = target_concept_coefficients.detach().cpu().numpy()
all_masks.append(mask)
all_coefficients.append(target_concept_coefficients)
reconstructed_images.append(reconstructed_image)
return all_masks, all_coefficients, reconstructed_images
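# Usage sketch (illustrative only; `MyFluxSegmenter` is a hypothetical concrete
# subclass that implements segment_individual_image):
#   segmenter = MyFluxSegmenter(...)
#   masks, coefficients, reconstructions = segmenter(
#       images=[pil_image],
#       target_concepts=["dog"],
#       concepts=["dog", "grass", "sky"],
#       captions=["a dog running on grass"],
#   )
# Each entry of `masks` is the boolean map obtained by thresholding the target
# concept's coefficients at their mean value.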
def add_noise_to_image(
encoded_image,
num_steps=50,
noise_timestep=49,
seed=63,
width=1024,
height=1024,
device="cuda",
is_schnell=True,
):
"""
Mixes an encoded latent with Gaussian noise at the schedule timestep indexed by
noise_timestep, returning the noised latent and the remaining timesteps.
"""
# Sample Gaussian noise with the same latent shape
x = get_noise(
1,
height,
width,
device=device,
dtype=torch.bfloat16,
seed=seed,
)
timesteps = get_schedule(
num_steps,
x.shape[-1] * x.shape[-2] // 4,
shift=(not is_schnell),
)
t = timesteps[noise_timestep]
timesteps = timesteps[noise_timestep:]
x = t * x + (1.0 - t) * encoded_image.to(x.dtype)
return x, timesteps
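# Usage sketch (illustrative; `encoded_latent` stands in for the output of
# encode_image below):
#   x_t, remaining_timesteps = add_noise_to_image(encoded_latent, num_steps=50, noise_timestep=49)
# The schedule runs from 1.0 toward 0.0, so a late noise_timestep index gives a
# small t and the mixed latent x_t = t * noise + (1 - t) * x_0 stays close to x_0.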
@torch.no_grad()
def encode_image(
image: PIL.Image.Image,
autoencoder: torch.nn.Module,
offload=True,
device="cuda",
height=1024,
width=1024,
):
"""
Encodes a PIL image to the VAE latent space and adds noise to it
"""
if isinstance(image, PIL.Image.Image):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: 2.0 * x - 1.0),
])
image = transform(image)
else:
transform = transforms.Compose([
transforms.Lambda(lambda x: 2.0 * x - 1.0),
])
image = transform(image)
# init_image = image.convert("RGB")
# init_image = np.array(image)
init_image = image
if isinstance(init_image, np.ndarray):
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
init_image = init_image.unsqueeze(0)
init_image = init_image.to(device)
init_image = torch.nn.functional.interpolate(init_image, (height, width))
if offload:
autoencoder.encoder.to(device)
init_image = autoencoder.encode(init_image)
if offload:
autoencoder = autoencoder.cpu()
torch.cuda.empty_cache()
return init_image
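# Usage sketch (illustrative; `generator` is a loaded FluxGenerator as used below):
#   latent = encode_image(pil_image, generator.ae, offload=True, device="cuda")
# The returned latent can then be passed to add_noise_to_image above.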
@torch.no_grad()
def generate_concept_basis_and_image_representation(
image: PIL.Image.Image,
caption: str,
concepts: list[str],
noise_timestep: int | list[int] = 49,
layers=list(range(19)),
normalize_concepts=True,
num_steps=50,
seed=63,
model_name="flux-schnell",
offload=True,
device="cuda",
target_space="output",
height=1024,
width=1024,
generator=None,
stop_after_multimodal_attentions=False,
num_samples=1,
joint_attention_kwargs=None,
reduce_dims=True,
**kwargs
):
"""
Takes a real image and extracts concept and image representation vectors from the
model's multi-modal attention layers, along with an optional reconstructed image.
"""
if generator is None:
# Load up the model
generator = FluxGenerator(
model_name,
device,
offload=offload
)
else:
model_name = generator.model_name
# Encode the image into the VAE latent space
encoded_image_without_noise = encode_image(
image,
generator.ae,
offload=offload,
device=device,
)
# Run num_samples noisy forward passes; the attention vectors cached inside the
# double blocks accumulate across samples and are collected after this loop
for i in range(num_samples):
# Add noise to image
encoded_image, timesteps = add_noise_to_image(
encoded_image_without_noise,
num_steps=num_steps,
noise_timestep=noise_timestep,
seed=seed + i,
width=width,
height=height,
device=device,
is_schnell=False,
)
# Encode the caption and the concept vectors with the text encoders before running the model
if offload:
generator.t5, generator.clip = generator.t5.to(device), generator.clip.to(device)
inp = prepare(t5=generator.t5, clip=generator.clip, img=encoded_image, prompt=caption)
concept_embeddings, concept_ids, concept_vec = embed_concepts(
generator.clip,
generator.t5,
concepts,
)
inp["concepts"] = concept_embeddings.to(encoded_image.device)
inp["concept_ids"] = concept_ids.to(encoded_image.device)
inp["concept_vec"] = concept_vec.to(encoded_image.device)
# offload TEs to CPU, load model to gpu
if offload:
generator.t5, generator.clip = generator.t5.cpu(), generator.clip.cpu()
torch.cuda.empty_cache()
generator.model = generator.model.to(device)
# Run a single denoising step on the noisy latent
guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype)
t_curr = timesteps[0]
t_prev = timesteps[1]
t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device)
pred = generator.model(
img=inp["img"],
img_ids=inp["img_ids"],
txt=inp["txt"],
txt_ids=inp["txt_ids"],
concepts=inp["concepts"],
concept_ids=inp["concept_ids"],
concept_vec=inp["concept_vec"],
null_txt=inp["null_txt"],
null_txt_vec=inp["null_txt_vec"],
null_txt_ids=inp["null_txt_ids"],
y=inp["concept_vec"],
timesteps=t_vec,
guidance=guidance_vec,
stop_after_multimodal_attentions=stop_after_multimodal_attentions,
joint_attention_kwargs=joint_attention_kwargs
)
if not stop_after_multimodal_attentions:
if offload:
generator.model.cpu()
torch.cuda.empty_cache()
generator.ae.decoder.to(pred.device)
img = inp["img"] + (t_prev - t_curr) * pred
# decode latents to pixel space
img = unpack(img.float(), height, width)
with torch.autocast(device_type=generator.device.type, dtype=torch.bfloat16):
img = generator.ae.decode(img)
if generator.offload:
generator.ae.decoder.cpu()
torch.cuda.empty_cache()
img = img.clamp(-1, 1)
img = einops.rearrange(img[0], "c h w -> h w c")
# Map the decoded image from [-1, 1] to [0, 255] and convert to a PIL image
reconstructed_image = PIL.Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy())
else:
img = None
reconstructed_image = None
# Offload the diffusion model and move the autoencoder decoder onto the device
if offload:
generator.model.cpu()
torch.cuda.empty_cache()
generator.ae.decoder.to(device)
# Collect the cached concept and image vectors from every double block
concept_vectors = []
image_vectors = []
for double_block in generator.model.double_blocks:
if target_space == "output":
image_vecs = torch.stack(
double_block.image_output_vectors
).squeeze(1)
concept_vecs = torch.stack(
double_block.concept_output_vectors
).squeeze(1)
elif target_space == "cross_attention":
image_vecs = torch.stack(
double_block.image_query_vectors
).squeeze(1)
concept_vecs = torch.stack(
double_block.concept_key_vectors
).squeeze(1)
else:
    raise ValueError(f"Unknown target_space: {target_space}")
# Clear this block's cached vectors (they accumulate across forward passes)
double_block.clear_cached_vectors()
# Add to list
concept_vectors.append(concept_vecs)
image_vectors.append(image_vecs)
# Stack layers
concept_vectors = torch.stack(concept_vectors).to(torch.float32)
image_vectors = torch.stack(image_vectors).to(torch.float32)
if layers is not None:
# Keep only the requested layers
concept_vectors = concept_vectors[layers]
image_vectors = image_vectors[layers]
# Apply linear normalization to concepts
if normalize_concepts:
concept_vectors = linear_normalization(concept_vectors, dim=-2)
if reduce_dims:
if len(image_vectors.shape) == 4:
image_vectors = einops.rearrange(
image_vectors,
"layers time patches d -> patches (layers time d)",
)
concept_vectors = einops.rearrange(
concept_vectors,
"layers time concepts d -> concepts (layers time d)"
)
else:
image_vectors = einops.rearrange(
image_vectors,
"layers time heads patches d -> patches (layers time heads d)",
)
concept_vectors = einops.rearrange(
concept_vectors,
"layers time heads concepts d -> concepts (layers time heads d)"
)
return image_vectors, concept_vectors, reconstructed_image
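

if __name__ == "__main__":
    # Minimal usage sketch (untested here): requires a GPU and the Flux weights that
    # FluxGenerator loads. The image path, caption, and concept list are placeholder
    # values, not part of the original module.
    example_image = PIL.Image.open("example.png").convert("RGB").resize((1024, 1024))
    image_vectors, concept_vectors, reconstruction = generate_concept_basis_and_image_representation(
        image=example_image,
        caption="a dog running on grass",
        concepts=["dog", "grass", "sky"],
        model_name="flux-schnell",
        device="cuda",
        offload=True,
    )
    # With reduce_dims=True (the default), rows of `image_vectors` index patches and
    # rows of `concept_vectors` index concepts; a simple segmentation heuristic is the
    # patch-to-concept similarity matrix.
    similarity = image_vectors @ concept_vectors.T  # (patches, concepts)
    print(similarity.shape)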