import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
import io
import base64
from PIL import Image
import logging
from fastapi import FastAPI
from pydantic import BaseModel

# Configure logging for debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the FastAPI app
app = FastAPI()

# Pydantic model for validating API request parameters
class ImageRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = False
    width: int = 1024
    height: int = 1024
    num_inference_steps: int = 4
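
# Illustrative request body for /api/infer (example values, not defaults
# beyond those declared above):
#   {"prompt": "a cat holding a sign that says hello world",
#    "seed": 123, "randomize_seed": false,
#    "width": 768, "height": 768, "num_inference_steps": 4}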

# Load the FLUX.1 [schnell] pipeline in bfloat16, on GPU when available
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for the given prompt; return the PIL image and the seed used."""
    logger.info(f"Calling infer with prompt={prompt}, seed={seed}, randomize_seed={randomize_seed}, width={width}, height={height}, num_inference_steps={num_inference_steps}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Generate the image (schnell is distilled for few-step, guidance-free
    # sampling, hence guidance_scale=0.0)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0
    ).images[0]

    # Return the PIL image directly so the Gradio UI can render it;
    # the API endpoint below handles Base64 encoding
    return image, seed


def pil_to_base64_png(image):
    # Encode a PIL image as a Base64 PNG string
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

# FastAPI endpoint for programmatic access
@app.post("/api/infer")
async def api_infer(request: ImageRequest):
    logger.info(f"API request received: {request}")
    image, seed = infer(
        prompt=request.prompt,
        seed=request.seed,
        randomize_seed=request.randomize_seed,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps
    )
    # Encode the image as a Base64 data URI for the JSON response
    return {"image_base64": f"data:image/png;base64,{pil_to_base64_png(image)}", "seed": seed}
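
# Example client call (hypothetical; assumes the FastAPI app is actually being
# served, e.g. on port 7860 -- see the note at the bottom of this file):
#   import requests
#   resp = requests.post("http://localhost:7860/api/infer",
#                        json={"prompt": "a tiny astronaut hatching from an egg on the moon"})
#   data_uri = resp.json()["image_base64"]
#   # strip the "data:image/png;base64," prefix before base64-decoding the PNG bytes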

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
        """)
        
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)
        seed_output = gr.Number(label="Seed", show_label=True)
        
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
            )
        
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed_output],
            cache_examples=True,
            cache_mode="lazy"
        )

    # Wire the UI controls directly to infer, which returns (image, seed)
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output]
    )

# Launch the Gradio UI (no queue, since the API path goes through FastAPI).
# Note: demo.launch() serves only the Gradio app; the FastAPI `app` defined
# above is not reachable unless it is served separately or the two are
# mounted together.
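# A minimal sketch of serving both together (assumes gr.mount_gradio_app,
# Gradio's helper for mounting a Blocks app onto a FastAPI app):
#   app = gr.mount_gradio_app(app, demo, path="/")
#   # then run with:  uvicorn app:app --host 0.0.0.0 --port 7860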
demo.launch()