"""
This baseline just returns heatmaps as the raw cross attentions.
"""
from concept_attention.flux.src.flux.sampling import prepare, unpack
import torch
import einops
import PIL
from concept_attention.image_generator import FluxGenerator
from concept_attention.segmentation import SegmentationAbstractClass, add_noise_to_image, encode_image
from concept_attention.utils import embed_concepts
class RawOutputSpaceBaseline:
    """
    This class implements the cross-attention baseline.
    """
def __init__(
self,
model_name: str = "flux-schnell",
device: str = "cuda",
offload: bool = True,
generator = None
):
        super().__init__()
# Load up the flux generator
if generator is None:
self.generator = FluxGenerator(
model_name=model_name,
device=device,
offload=offload,
)
else:
self.generator = generator
        # Expose the T5 tokenizer
self.tokenizer = self.generator.t5.tokenizer
def __call__(
self,
prompt,
concepts,
seed=4,
num_steps=4,
timesteps=None,
        layers=None,
softmax=False,
height=1024,
width=1024,
guidance=0.0,
):
"""
Generate cross attention heatmap visualizations.
Args:
- prompt: str, the prompt to generate the visualizations for
- seed: int, the seed to use for the visualization
Returns:
- attention_maps: torch.Tensor, the attention maps for the prompt
- tokens: list[str], the tokens in the prompt
- image: torch.Tensor, the image generated by the
"""
if timesteps is None:
timesteps = list(range(num_steps))
if layers is None:
layers = list(range(19))
# Run the image generator
image, _, all_concept_heatmaps = self.generator.generate_image(
            width=width,
            height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
prompt=prompt,
concepts=concepts
)
        # Optionally turn each patch's scores into a distribution over concepts
if softmax:
all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2)
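        # all_concept_heatmaps: (time, layers, batch, concepts, patches);
        # keep only the requested layers, then average over time and layers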
concept_heatmaps = all_concept_heatmaps[:, layers]
concept_heatmaps = einops.reduce(
concept_heatmaps,
"time layers batch concepts patches -> batch concepts patches",
reduction="mean"
)
        # Convert to float32 and reshape the flat patch axis into a square grid.
        # Flux latents are 16x downsampled (8x VAE, 2x patchify), so a 1024x1024
        # image yields a 64x64 patch grid.
        concept_heatmaps = concept_heatmaps.to(torch.float32)
        concept_heatmaps = einops.rearrange(
            concept_heatmaps,
            "batch concepts (h w) -> batch concepts h w",
            h=height // 16,
            w=width // 16
        )
return concept_heatmaps, image
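
# A minimal usage sketch for the baseline above (illustrative values; assumes a
# CUDA device and locally available flux-schnell weights):
#
#   baseline = RawOutputSpaceBaseline(model_name="flux-schnell")
#   concept_heatmaps, image = baseline(
#       prompt="a dog playing in a grassy park",
#       concepts=["dog", "grass", "sky"],
#   )
#   # concept_heatmaps: (batch, len(concepts), 64, 64) float32 tensor
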
class RawOutputSpaceSegmentationModel(SegmentationAbstractClass):
def __init__(
self,
model_name: str = "flux-schnell",
device: str = "cuda",
offload: bool = True,
generator=None,
):
"""
Initialize the segmentation model.
"""
        super().__init__()
if generator is not None:
self.generator = generator
else:
# Load up the flux generator
self.generator = FluxGenerator(
model_name=model_name,
device=device,
offload=offload,
)
self.is_schnell = "schnell" in model_name
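        # flux-schnell and flux-dev use different timestep schedules downstream
        # (see add_noise_to_image)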
def segment_individual_image(
self,
image: PIL.Image.Image,
concepts: list[str],
caption: str,
device: str = "cuda",
offload: bool = False,
num_samples: int = 1,
num_steps: int = 4,
noise_timestep: int = 2,
seed: int = 4,
width: int = 1024,
height: int = 1024,
stop_after_multimodal_attentions: bool = True,
        layers: list[int] | None = None,
normalize_concepts=True,
softmax: bool = False,
joint_attention_kwargs=None,
**kwargs
):
"""
Takes a real image and generates a segmentation map.
"""
# Encode the image into the VAE latent space
encoded_image_without_noise = encode_image(
image,
self.generator.ae,
offload=offload,
device=device,
)
        # Run num_samples independent noise draws and average the heatmaps later
all_concept_heatmaps = []
for i in range(num_samples):
            # Forward-diffuse the encoded image to the chosen noise timestep
encoded_image, timesteps = add_noise_to_image(
encoded_image_without_noise,
num_steps=num_steps,
noise_timestep=noise_timestep,
seed=seed + i,
width=width,
height=height,
device=device,
is_schnell=self.is_schnell,
)
            # Encode the caption and the concept vectors
if offload:
self.generator.t5, self.generator.clip = self.generator.t5.to(device), self.generator.clip.to(device)
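            # prepare() packs the latents into patch tokens and encodes the
            # caption with T5/CLIP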
inp = prepare(t5=self.generator.t5, clip=self.generator.clip, img=encoded_image, prompt=caption)
concept_embeddings, concept_ids, concept_vec = embed_concepts(
self.generator.clip,
self.generator.t5,
concepts,
)
inp["concepts"] = concept_embeddings.to(encoded_image.device)
inp["concept_ids"] = concept_ids.to(encoded_image.device)
inp["concept_vec"] = concept_vec.to(encoded_image.device)
            # Offload the text encoders to CPU and load the transformer onto the GPU
if offload:
self.generator.t5, self.generator.clip = self.generator.t5.cpu(), self.generator.clip.cpu()
torch.cuda.empty_cache()
self.generator.model = self.generator.model.to(device)
            # Run the diffusion model once on the noisy image
guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype)
t_curr = timesteps[0]
t_prev = timesteps[1]
t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device)
pred, _, concept_heatmaps = self.generator.model(
img=inp["img"],
img_ids=inp["img_ids"],
txt=inp["txt"],
txt_ids=inp["txt_ids"],
concepts=inp["concepts"],
concept_ids=inp["concept_ids"],
concept_vec=inp["concept_vec"],
y=inp["concept_vec"],
timesteps=t_vec,
guidance=guidance_vec,
stop_after_multimodal_attentions=stop_after_multimodal_attentions,
joint_attention_kwargs=joint_attention_kwargs,
)
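            # concept_heatmaps holds the raw cross-attention scores between
            # concept tokens and image patches, for every layer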
all_concept_heatmaps.append(concept_heatmaps)
all_concept_heatmaps = torch.stack(all_concept_heatmaps, dim=0)
        # Offload the transformer and load the autoencoder's decoder before decoding
        if offload:
            self.generator.model.cpu()
            torch.cuda.empty_cache()
            self.generator.ae.decoder.to(device)
        if not stop_after_multimodal_attentions:
            # Take one Euler step with the predicted velocity, then decode the
            # latents to pixel space
            img = inp["img"] + (t_prev - t_curr) * pred
            img = unpack(img.float(), height, width)
            with torch.autocast(device_type=self.generator.device.type, dtype=torch.bfloat16):
                img = self.generator.ae.decode(img)
            if self.generator.offload:
                self.generator.ae.decoder.cpu()
                torch.cuda.empty_cache()
            # Map from [-1, 1] to [0, 255] and convert to a PIL image
            img = img.clamp(-1, 1)
            img = einops.rearrange(img[0], "c h w -> h w c")
            reconstructed_image = PIL.Image.fromarray((127.5 * (img + 1.0)).cpu().byte().numpy())
        else:
            img = None
            reconstructed_image = None
        # NOTE: normalize_concepts is accepted for interface compatibility but
        # is not used by this baseline.
        # Optionally turn each patch's scores into a distribution over concepts
if softmax:
all_concept_heatmaps = torch.nn.functional.softmax(all_concept_heatmaps, dim=-2)
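        # all_concept_heatmaps: (samples, layers, batch, concepts, patches);
        # keep only the requested layers, then average over samples and layers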
concept_heatmaps = all_concept_heatmaps[:, layers]
concept_heatmaps = einops.reduce(
concept_heatmaps,
"samples layers batch concepts patches -> batch concepts patches",
reduction="mean"
)
        # Convert to float32 and reshape the flat patch axis into a
        # (height // 16) x (width // 16) grid, as in the baseline above
        concept_heatmaps = concept_heatmaps.to(torch.float32)
        concept_heatmaps = einops.rearrange(
            concept_heatmaps,
            "batch concepts (h w) -> batch concepts h w",
            h=height // 16,
            w=width // 16
        )
return concept_heatmaps, reconstructed_image
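
if __name__ == "__main__":
    # A minimal smoke test (a sketch, not part of the original file): assumes a
    # CUDA device, locally available flux-schnell weights, and an example image
    # at a hypothetical path. Swap in your own image, caption, and concepts.
    model = RawOutputSpaceSegmentationModel(model_name="flux-schnell", device="cuda")
    example = PIL.Image.open("example.png").convert("RGB").resize((1024, 1024))
    concept_heatmaps, _ = model.segment_individual_image(
        image=example,
        concepts=["cat", "couch", "background"],
        caption="a cat sitting on a couch",
    )
    # Assign each patch to its highest-scoring concept
    segmentation = concept_heatmaps.argmax(dim=1)  # (batch, 64, 64)
    print(segmentation.shape)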