import gradio as gr import numpy as np import random import torch from diffusers import DiffusionPipeline from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast import boto3 from io import BytesIO import time import os # S3 Configuration S3_BUCKET = "afri" S3_REGION = "eu-west-3" S3_ACCESS_KEY_ID = "AKIAQQABC7IQWFLKSE62" S3_SECRET_ACCESS_KEY = "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" # Set up S3 client s3_client = boto3.client('s3', region_name=S3_REGION, aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY) dtype = torch.bfloat16 device = "cuda" if torch.cuda.is_available() else "cpu" pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device) MAX_SEED = np.iinfo(np.int32).max MAX_IMAGE_SIZE = 2048 def save_image_to_s3(image): img_byte_arr = BytesIO() image.save(img_byte_arr, format='PNG') img_byte_arr = img_byte_arr.getvalue() filename = f"generated_image_{int(time.time())}.png" s3_client.put_object(Bucket=S3_BUCKET, Key=filename, Body=img_byte_arr, ContentType='image/png') url = f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}" return url def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)): if randomize_seed: seed = random.randint(0, MAX_SEED) generator = torch.Generator().manual_seed(seed) image = pipe( prompt=prompt, width=width, height=height, num_inference_steps=num_inference_steps, generator=generator, guidance_scale=guidance_scale ).images[0] image_url = save_image_to_s3(image) return image_url, seed examples = [ "a tiny astronaut hatching from an egg on the moon", "a cat holding a sign that says hello world", "an anime illustration of a wiener schnitzel", ] css=""" #col-container { margin: 0 auto; max-width: 520px; } """ with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown(f"""# FLUX.1 [dev] 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)] """) with gr.Row(): prompt = gr.Text( label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) run_button = gr.Button("Run", scale=0) result = gr.Text(label="Image URL", show_label=True) with gr.Accordion("Advanced Settings", open=False): seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) with gr.Row(): width = gr.Slider( label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) height = gr.Slider( label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) with gr.Row(): guidance_scale = gr.Slider( label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=28, ) gr.Examples( examples=examples, fn=infer, inputs=[prompt], outputs=[result, seed], cache_examples="lazy" ) gr.on( triggers=[run_button.click, prompt.submit], fn=infer, inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps], outputs=[result, seed] ) demo.launch(share=True)