"""Gradio demo for ConceptAttention on the Flux-Schnell pipeline.

The user supplies a prompt, a comma-separated list of concept words, a seed
and (optionally) an image.  The pipeline either encodes the uploaded image or
generates a new one, and the per-concept heatmaps it returns are shown as
inline thumbnails below the output image.
"""
import base64
import html
import io

# Third-party import order is kept exactly as in the original file.
import spaces
import gradio as gr
from PIL import Image

from concept_attention import ConceptAttentionFluxPipeline

# Default settings shared by the pipeline constructor and both pipeline calls.
concept_attention_default_args = {
    "model_name": "flux-schnell",
    "device": "cuda",
    "layer_indices": list(range(10, 19)),
    "timesteps": list(range(4)),
    "num_samples": 4,
    "num_inference_steps": 4,
}

# Edge length, in pixels, of each heatmap thumbnail shown in the UI.
IMG_SIZE = 250

# Rows for gr.Examples; column order matches the `inputs` list passed there.
EXAMPLES = [
    [
        "A fluffy cat sitting on a windowsill",  # prompt
        "cat.jpg",  # image
        "fur, whiskers, eyes",  # words
        42,  # seed
    ],
    # ["Mountain landscape with lake", "cat.jpg", "sky, trees, water", 123],
    # ["Portrait of a young woman", "monkey.png", "face, hair, eyes", 456],
]

# Single pipeline instance, built once at import time and reused per request.
pipeline = ConceptAttentionFluxPipeline(
    model_name=concept_attention_default_args["model_name"],
    device=concept_attention_default_args["device"],
)


def _parse_concepts(word_list):
    """Split a comma-separated string into a list of non-empty, stripped words.

    Returns an empty list for ``None``, blank input, or input that contains
    only commas/whitespace (e.g. ``" , ,"``).
    """
    if not word_list:
        return []
    return [word.strip() for word in word_list.split(",") if word.strip()]


def _heatmap_to_html(concept, heatmap):
    """Render one concept heatmap as an HTML card with an inline PNG.

    ``heatmap`` must support ``.resize(...)`` and ``.save(buffer, format=...)``
    (presumably a ``PIL.Image.Image`` -- TODO confirm against the pipeline).
    """
    thumbnail = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
    buffer = io.BytesIO()
    thumbnail.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    # The concept text is user input: escape it before placing it in markup.
    label = html.escape(concept)
    return (
        '<div style="display: inline-block; margin: 10px; text-align: center;">'
        f"<h4>{label}</h4>"
        f'<img src="data:image/png;base64,{encoded}" '
        f'width="{IMG_SIZE}" height="{IMG_SIZE}" alt="{label}"/>'
        "</div>"
    )


@spaces.GPU(duration=60)
def process_inputs(prompt, input_image, word_list, seed):
    """Run the pipeline and build the outputs for the Gradio UI.

    Args:
        prompt: Text prompt; surrounding whitespace is stripped.
        input_image: Array from ``gr.Image(type="numpy")`` or ``None``.  When
            given, the image is encoded; otherwise a new image is generated.
        word_list: Comma-separated concept words.
        seed: Seed value from the slider.

    Returns:
        ``(output_image, html_string)`` on success, or
        ``(None, "Please enter comma-separated words")`` when no usable
        concept was supplied.
    """
    prompt = (prompt or "").strip()
    concepts = _parse_concepts(word_list)
    if not concepts:
        return None, "Please enter comma-separated words"

    # The slider uses step=1; coerce in case the value arrives as a float.
    seed = int(seed)

    if input_image is not None:
        source_image = Image.fromarray(input_image).convert("RGB")
        source_image = source_image.resize((1024, 1024))
        pipeline_output = pipeline.encode_image(
            image=source_image,
            concepts=concepts,
            prompt=prompt,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=concept_attention_default_args["num_samples"],
        )
    else:
        pipeline_output = pipeline.generate_image(
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            timesteps=concept_attention_default_args["timesteps"],
            num_inference_steps=concept_attention_default_args[
                "num_inference_steps"
            ],
        )

    output_image = pipeline_output.image
    concept_heatmaps = pipeline_output.concept_heatmaps

    # One card per (concept, heatmap) pair; zip stops at the shorter sequence.
    cards = [
        _heatmap_to_html(concept, heatmap)
        for concept, heatmap in zip(concepts, concept_heatmaps)
    ]
    combined_html = "\n" + "".join(cards) + "\n"
    return output_image, combined_html


CUSTOM_CSS = """
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
.title {
    text-align: center;
    margin-bottom: 10px;
}
.authors {
    text-align: center;
    margin-bottom: 20px;
}
.affiliations {
    text-align: center;
    color: #666;
    margin-bottom: 40px;
}
.content {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 20px;
}
.section {
    border: 2px solid #ddd;
    border-radius: 10px;
    padding: 20px;
}
"""

# NOTE(review): the original file's indentation was lost, so the nesting of
# the layout below is reconstructed from statement order -- confirm visually.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            "# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features",
            elem_classes="title",
        )
        gr.Markdown(
            "**Alec Helbling**¹, **Tuna Meral**², **Ben Hoover**¹³, **Pinar Yanardag**², **Duen Horng (Polo) Chau**¹",
            elem_classes="authors",
        )
        gr.Markdown(
            "¹Georgia Tech · ²Virginia Tech · ³IBM Research",
            elem_classes="affiliations",
        )

        with gr.Row(elem_classes="content"):
            with gr.Column(elem_classes="section"):
                gr.Markdown("### Input")
                prompt = gr.Textbox(label="Enter your prompt")
                words = gr.Textbox(label="Enter words (comma-separated)")
                seed = gr.Slider(
                    minimum=0, maximum=10000, step=1, label="Seed", value=42
                )
                gr.HTML("\n\nOr\n\n")
                image_input = gr.Image(
                    type="numpy", label="Upload image (optional)"
                )

            with gr.Column(elem_classes="section"):
                gr.Markdown("### Output")
                output_image = gr.Image(type="numpy", label="Output image")

        with gr.Row():
            submit_btn = gr.Button("Process")

        with gr.Row(elem_classes="section"):
            saliency_display = gr.HTML(label="Saliency Maps")

        submit_btn.click(
            fn=process_inputs,
            inputs=[prompt, image_input, words, seed],
            outputs=[output_image, saliency_display],
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=[prompt, image_input, words, seed],
            outputs=[output_image, saliency_display],
            fn=process_inputs,
            cache_examples=False,
        )

if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        inbrowser=True,
        # share=False,
        server_port=6754,
        quiet=True,
        max_threads=1,
    )