import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
import io
import base64
import logging
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel

# Configure logging for debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI()

# Pydantic model for validating the API parameters
class ImageRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = False
    width: int = 1024
    height: int = 1024
    num_inference_steps: int = 4

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    logger.info(
        f"Calling infer with prompt={prompt}, seed={seed}, "
        f"randomize_seed={randomize_seed}, width={width}, height={height}, "
        f"num_inference_steps={num_inference_steps}"
    )
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Generate the image (guidance is disabled for the distilled schnell model)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    ).images[0]

    # Return the PIL image plus the seed actually used. Base64 encoding is
    # done only in the API layer below, so the Gradio Image component can
    # consume the PIL image directly.
    return image, seed

def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a PNG data URI."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"

# FastAPI endpoint
@app.post("/api/infer")
async def api_infer(request: ImageRequest):
    logger.info(f"API request received: {request}")
    image, seed = infer(
        prompt=request.prompt,
        seed=request.seed,
        randomize_seed=request.randomize_seed,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
    )
    return {"image_base64": image_to_base64(image), "seed": seed}

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
""")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)
        seed_output = gr.Number(label="Seed", show_label=True)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
            )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed_output],
            cache_examples=True,
            cache_mode="lazy",
        )
    # infer already returns (image, seed), matching the two output components,
    # so the UI can call it directly; no extra formatting wrapper is needed.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output],
    )

# Mount the Gradio UI onto the FastAPI app so the web interface and the
# /api/infer endpoint are served by the same server; demo.launch() alone would
# never expose the FastAPI routes. No Gradio queue is used, since the API goes
# through FastAPI.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
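# A minimal client sketch for the /api/infer endpoint (illustrative only: it
# assumes the server above is reachable at http://localhost:7860 and that the
# `requests` package is installed; neither is defined by this app itself):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/infer",
#       json={"prompt": "a cat holding a sign that says hello world",
#             "randomize_seed": True},
#   )
#   data = resp.json()
#   # data["image_base64"] is a PNG data URI; data["seed"] is the seed used.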