# concept_attention/concept_attention_pipeline.py
"""
Wrapper pipeline for concept attention.
"""
from dataclasses import dataclass
import PIL.Image
import numpy as np
import matplotlib.pyplot as plt

from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
from concept_attention.image_generator import FluxGenerator


@dataclass
class ConceptAttentionPipelineOutput:
    image: PIL.Image.Image | np.ndarray
    concept_heatmaps: list[PIL.Image.Image] | np.ndarray


class ConceptAttentionFluxPipeline:
"""
    Pipeline for generating images with Flux and for 'encoding' existing
    images with Flux, yielding one attention heatmap per concept.
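
    Example (illustrative; assumes a CUDA-capable GPU and that the
    flux-schnell weights can be downloaded):

        pipeline = ConceptAttentionFluxPipeline(
            model_name="flux-schnell", device="cuda:0"
        )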
"""
def __init__(
self,
model_name: str = "flux-schnell",
        offload_model: bool = False,
        device: str = "cuda:0"
):
self.model_name = model_name
        self.offload_model = offload_model
# Load the generator
self.flux_generator = FluxGenerator(
model_name=model_name,
offload=offload_model,
device=device
)
        # Build the raw cross-attention and raw output-space segmentation models,
        # plus the matching generator baselines, all sharing the same Flux generator.
self.cross_attention_segmentation_model = RawCrossAttentionSegmentationModel(
generator=self.flux_generator
)
self.output_space_segmentation_model = RawOutputSpaceSegmentationModel(
generator=self.flux_generator
)
self.raw_output_space_generator = RawOutputSpaceBaseline(
generator=self.flux_generator
)
self.raw_cross_attention_generator = RawCrossAttentionBaseline(
generator=self.flux_generator
)

    def generate_image(
self,
prompt: str,
concepts: list[str],
width: int = 1024,
height: int = 1024,
        return_cross_attention: bool = False,
        layer_indices: list[int] | None = None,
        return_pil_heatmaps: bool = True,
seed: int = 0,
num_inference_steps: int = 4,
guidance: float = 0.0,
        timesteps: list[int] | None = None,
        softmax: bool = True,
        cmap: str = "plasma"
) -> ConceptAttentionPipelineOutput:
"""
        Generate an image with Flux and one attention heatmap per concept.
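
        Example (illustrative sketch; assumes a CUDA device and that the
        flux-schnell weights are available):

            pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell")
            output = pipeline.generate_image(
                prompt="a dog by a tree",
                concepts=["dog", "tree", "sky"],
            )
            # output.image may be a PIL image or a numpy array
            output.concept_heatmaps[0].save("dog_heatmap.png")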
"""
        assert return_cross_attention is False, "Not supported yet"
        if layer_indices is None:
            layer_indices = list(range(15, 19))
        assert all(0 <= layer_index < 19 for layer_index in layer_indices), "Invalid layer index"
        assert height == width, "Height and width must be the same for now"
        if timesteps is None:
            timesteps = list(range(num_inference_steps))
        # Generate the image and per-concept heatmaps with the raw output-space baseline
concept_heatmaps, image = self.raw_output_space_generator(
prompt,
concepts,
seed=seed,
num_steps=num_inference_steps,
timesteps=timesteps,
layers=layer_indices,
softmax=softmax,
            height=height,
width=width,
guidance=guidance,
)
        # Pull the heatmaps onto the CPU as numpy arrays and drop the batch dimension
        concept_heatmaps = concept_heatmaps.detach().cpu().numpy()[0]
        # Convert the numpy heatmaps to PIL images.
if return_pil_heatmaps:
# Convert to a matplotlib color scheme
colored_heatmaps = []
            for concept_heatmap in concept_heatmaps:
                # Min-max normalize each heatmap independently before applying the colormap
                concept_heatmap = (concept_heatmap - concept_heatmap.min()) / (concept_heatmap.max() - concept_heatmap.min())
colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
colored_heatmaps.append(rgb_image)
concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
return ConceptAttentionPipelineOutput(
image=image,
concept_heatmaps=concept_heatmaps
)

    def encode_image(
self,
image: PIL.Image.Image,
concepts: list[str],
prompt: str = "", # Optional
width: int = 1024,
height: int = 1024,
        return_cross_attention: bool = False,
        layer_indices: list[int] | None = None,
num_samples: int = 1,
device: str = "cuda:0",
return_pil_heatmaps: bool = True,
seed: int = 0,
cmap="plasma"
) -> ConceptAttentionPipelineOutput:
"""
        Encode an existing image with Flux, producing one attention heatmap
        per concept.
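
        Example (illustrative sketch; `pipeline` is a constructed
        ConceptAttentionFluxPipeline and `photo.png` stands in for any
        RGB image on disk):

            image = PIL.Image.open("photo.png").convert("RGB")
            output = pipeline.encode_image(
                image=image,
                concepts=["cat", "grass", "sky"],
                prompt="a cat sitting on grass",
            )
            output.concept_heatmaps[0].save("cat_heatmap.png")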
"""
        assert return_cross_attention is False, "Not supported yet"
        if layer_indices is None:
            layer_indices = list(range(15, 19))
        assert all(0 <= layer_index < 19 for layer_index in layer_indices), "Invalid layer index"
        assert height == width, "Height and width must be the same for now"
        # Segment the image using the raw output-space segmentation model
concept_heatmaps, _ = self.output_space_segmentation_model.segment_individual_image(
image=image,
concepts=concepts,
caption=prompt,
device=device,
softmax=True,
layers=layer_indices,
num_samples=num_samples,
height=height,
width=width
)
        # Pull the heatmaps onto the CPU as numpy arrays
        concept_heatmaps = concept_heatmaps.detach().cpu().numpy()
        # Convert the numpy heatmaps to PIL images.
        if return_pil_heatmaps:
            # Normalize all heatmaps with a shared min/max so they stay comparable
            min_val = concept_heatmaps.min()
            max_val = concept_heatmaps.max()
            # Apply a matplotlib colormap
            colored_heatmaps = []
            for concept_heatmap in concept_heatmaps:
                concept_heatmap = (concept_heatmap - min_val) / (max_val - min_val)
colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
colored_heatmaps.append(rgb_image)
concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
return ConceptAttentionPipelineOutput(
image=image,
concept_heatmaps=concept_heatmaps
)
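

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original module API): generate
    # an image and save the per-concept heatmaps. Assumes a CUDA-capable GPU
    # and that the flux-schnell weights can be downloaded.
    pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda:0")
    concepts = ["cat", "grass", "sky"]
    output = pipeline.generate_image(
        prompt="a cat sitting on grass under a blue sky",
        concepts=concepts,
        num_inference_steps=4,
        seed=0,
    )
    # The generated image may come back as a PIL image or a numpy array.
    image = output.image
    if isinstance(image, np.ndarray):
        image = PIL.Image.fromarray(image)
    image.save("generated.png")
    for concept, heatmap in zip(concepts, output.concept_heatmaps):
        heatmap.save(f"heatmap_{concept}.png")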