# --- Hugging Face file-viewer residue (page chrome, not part of the program) ---
# ig / app.py
# Afrinetwork7's picture
# Update app.py
# 8291039 verified
# raw
# history blame
# 5.01 kB
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
import boto3
import os
from io import BytesIO
import time
# S3 configuration.
#
# SECURITY: credentials must never be committed to source control. The key
# pair that used to be hard-coded here has been exposed and must be revoked
# and rotated in AWS IAM. Values are now read from the environment (for a
# hosted deployment, configure them as secrets).
#
# The bucket and region keep their previous values as defaults so existing
# deployments continue to target the same location.
S3_BUCKET = os.environ.get("S3_BUCKET", "afri")
S3_REGION = os.environ.get("S3_REGION", "eu-west-3")
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Set up the S3 client. When either credential is None, boto3 falls back to
# its default credential provider chain instead of using an explicit key pair.
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Precision used when loading the model weights.
dtype = torch.bfloat16

# Load the FLUX.1-dev pipeline with that precision and move it to `device`.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
).to(device)

# Upper bounds shared by the UI sliders and by seed randomisation.
MAX_IMAGE_SIZE = 2048
MAX_SEED = np.iinfo(np.int32).max
def save_image_to_s3(image):
    """Upload an image to S3 as a PNG and return the URL of the stored object.

    Args:
        image: An object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The ``https://<bucket>.s3.<region>.amazonaws.com/<key>`` URL of the
        uploaded object. The object is written with the ``public-read`` ACL.
    """
    # Imported locally so this function does not depend on a change to the
    # module's import block.
    import uuid

    # Serialise the image to PNG bytes in memory.
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()

    # A one-second timestamp alone is not unique: two images finished within
    # the same second would share a key and the later upload would silently
    # overwrite the earlier one. The random suffix makes every key distinct
    # while keeping the original "generated_image_<timestamp>" prefix.
    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"

    # Upload to S3.
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=png_bytes,
        ContentType='image/png',
        ACL='public-read',
    )

    # Build the URL for the uploaded object.
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt``, upload it, and return its URL.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, a seed is drawn from ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Guidance scale, forwarded to the pipeline.
        num_inference_steps: Step count, forwarded to the pipeline.
        progress: Not referenced in the body; presumably declared so Gradio
            can attach its progress tracker — confirm against Gradio docs.

    Returns:
        A ``(image_url, seed)`` tuple, where ``seed`` is the value actually
        used for generation.
    """
    # Replace the caller's seed with a fresh random one when asked to.
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # Seeded generator so a given seed reproduces the same sampling noise.
    rng = torch.Generator().manual_seed(seed)

    pipeline_output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
    )
    generated = pipeline_output.images[0]

    # Save image to S3 and hand back its URL together with the seed used.
    return save_image_to_s3(generated), seed
# Sample prompts offered through the examples widget in the UI.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# Page stylesheet: centres the main column and caps its width. Assembled line
# by line; the empty first and last entries supply the leading and trailing
# newlines of the stylesheet text.
css = "\n".join(
    [
        "",
        "#col-container {",
        "margin: 0 auto;",
        "max-width: 520px;",
        "}",
        "",
    ]
)
# Header text shown at the top of the page. A plain literal is enough: the
# text contains no placeholders to interpolate.
_HEADER_MARKDOWN = """# FLUX.1 [dev]
12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
"""

# NOTE(review): the nesting below is reconstructed from how the components are
# used (prompt + button share a row, sliders live in the accordion) — confirm
# against the deployed layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_HEADER_MARKDOWN)

        # Prompt box with the run button beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                placeholder="Enter your prompt",
                max_lines=1,
                show_label=False,
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # The app reports the uploaded image's URL as text.
        result = gr.Text(label="Image URL", show_label=True)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                value=0,
                minimum=0,
                maximum=MAX_SEED,
                step=1,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    value=1024,
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                )
                height = gr.Slider(
                    label="Height",
                    value=1024,
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    value=3.5,
                    minimum=1,
                    maximum=15,
                    step=0.1,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    value=28,
                    minimum=1,
                    maximum=50,
                    step=1,
                )

        # Example prompts supply only the prompt, so `infer` runs with its own
        # defaults for every other parameter.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy"
        )

    # Clicking the button or submitting the prompt box both run `infer`; the
    # seed output writes the seed actually used back into the slider.
    generation_inputs = [
        prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=generation_inputs,
        outputs=[result, seed]
    )

# Start the Gradio server.
demo.launch()