import spaces import gradio as gr from PIL import Image import math import io import base64 import subprocess import os from concept_attention import ConceptAttentionFluxPipeline IMG_SIZE = 210 COLUMNS = 5 def update_default_concepts(prompt): default_concepts = { "A dog by a tree": ["dog", "grass", "tree", "background"], "A man on the beach": ["man", "dirt", "ocean", "sky"], "A hot air balloon": ["balloon", "sky", "water", "tree"] } return gr.update(value=default_concepts.get(prompt, [])) pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda:2", offload_model=True) def convert_pil_to_bytes(img): img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) buffered = io.BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() return img_str @spaces.GPU(duration=60) def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index): if not prompt: raise gr.exceptions.InputError("prompt", "Please enter a prompt") if not prompt.strip(): raise gr.exceptions.InputError("prompt", "Please enter a prompt") prompt = prompt.strip() if len(concepts) == 0: raise gr.exceptions.InputError("words", "Please enter at least 1 concept") if len(concepts) > 9: raise gr.exceptions.InputError("words", "Please enter at most 9 concepts") pipeline_output = pipeline.generate_image( prompt=prompt, concepts=concepts, width=1024, height=1024, seed=seed, timesteps=list(range(timestep_start_index, 4)), num_inference_steps=4, layer_indices=list(range(layer_start_index, 19)), softmax=True if len(concepts) > 1 else False ) output_image = pipeline_output.image output_space_heatmaps = pipeline_output.concept_heatmaps output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps] output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))] cross_attention_heatmaps = pipeline_output.cross_attention_maps cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps] cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))] return output_image, \ gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \ gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels)) with gr.Blocks( css=""" .container { max-width: 1300px; margin: 0 auto; padding: 20px; } .application { max-width: 1200px; } .generated-image { display: flex; align-items: center; justify-content: center; height: 100%; /* Ensures full height */ } .input { height: 47px; } .input-column { flex-direction: column; gap: 0px; height: 100%; } .input-column-label {} .gallery { height: 200px; } .run-button-column { width: 100px !important; } .gallery-container { scrollbar-width: thin; scrollbar-color: grey black; } /* Show only on screens wider than 768px (adjust as needed) @media (min-width: 1024px) { .svg-container { min-width: 150px; width: 200px; padding-top: 540px; } } @media (min-width: 1280px) { .svg-container { min-width: 200px; width: 300px; padding-top: 420px; } } @media (min-width: 1530px) { .svg-container { min-width: 200px; width: 300px; padding-top: 400px; } } */ @media (min-width: 1024px) { .svg-container { min-width: 250px; } #concept-attention-callout-svg { width: 250px; } } @media (max-width: 1024px) { .svg-container { display: none !important; } #concept-attention-callout-svg { display: none; } } .header { display: flex; flex-direction: column; } #title { font-size: 4.4em; color: #F3B13E; text-align: center; margin: 5px; } #subtitle { font-size: 3.0em; color: #FAE2BA; text-align: center; margin: 5px; } #abstract { text-align: center; font-size: 2.0em; color:rgb(219, 219, 219); margin: 5px; margin-top: 10px; } #links { text-align: center; font-size: 2.0em; margin: 5px; } #links a { color: #93B7E9; text-decoration: none; } .svg-container { display: flex; justify-content: center; align-items: center; } .caption-label { font-size: 1.15em; } .gallery label { font-size: 1.15em; } """ ) as demo: # with gr.Column(elem_classes="container"): with gr.Row(elem_classes="container", scale=8): with gr.Column(elem_classes="application-content", scale=10): with gr.Row(scale=3, elem_classes="header"): gr.HTML("""