File size: 4,452 Bytes
666f8bf
8bf4029
 
 
 
a8f57f8
e450df1
 
 
e5cbbd6
8bf4029
a8f57f8
8bf4029
 
a8f57f8
684e9f3
8bf4029
a8f57f8
8bf4029
a8f57f8
 
8bf4029
 
 
f611d13
 
d48e497
f611d13
 
 
 
 
 
 
 
e450df1
 
 
 
f611d13
e450df1
 
f611d13
e5cbbd6
 
 
 
 
8bf4029
a8f57f8
 
 
8bf4029
 
f611d13
8bf4029
 
a8f57f8
8bf4029
 
 
 
 
a8f57f8
 
 
 
 
8bf4029
 
 
 
 
 
 
 
a8f57f8
 
8bf4029
e450df1
a8f57f8
8bf4029
 
 
 
 
 
 
 
 
 
 
 
a8f57f8
8bf4029
 
a8f57f8
8bf4029
 
 
a8f57f8
8bf4029
 
 
 
f611d13
 
 
 
 
 
 
a8f57f8
 
f611d13
 
 
e450df1
 
 
a8f57f8
8bf4029
e450df1
 
 
 
f611d13
e5cbbd6
8bf4029
 
e450df1
f611d13
e450df1
8bf4029
55bf26f
e5cbbd6
 
 
f611d13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
import io
import base64
from PIL import Image
import json

# Run on the GPU when torch reports CUDA as available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Weights are loaded in bfloat16.
dtype = torch.bfloat16

# Load the FLUX.1 [schnell] pipeline once at import time, then move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Largest int32 value: upper bound for random seeds and for the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 2048

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` and return it as a Base64 PNG data URI.

    Args:
        prompt: Text prompt forwarded to the diffusion pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, replace ``seed`` with a random integer in
            ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        num_inference_steps: Number of denoising steps, forwarded to the pipeline.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A dict with two keys: ``"image_base64"`` (a ``data:image/png;base64,...``
        string holding the PNG-encoded image) and ``"seed"`` (the seed that was
        actually used).
    """
    # Only draw a fresh seed on request; otherwise honour the caller's value.
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(seed)

    # Run the pipeline; guidance is fixed at 0.0 and the first image is kept.
    pipeline_output = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=rng,
        guidance_scale=0.0,
    )
    picture = pipeline_output.images[0]

    # Serialise the image to PNG in memory, then Base64-encode the bytes.
    png_bytes = io.BytesIO()
    picture.save(png_bytes, format="PNG")
    encoded = base64.b64encode(png_bytes.getvalue()).decode("utf-8")

    # JSON-friendly payload: the data URI plus the seed that produced it.
    return {
        "image_base64": "data:image/png;base64," + encoded,
        "seed": seed,
    }

# Entry point intended for the custom API.
def api_infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
    """Call ``infer`` with the given arguments and return its dict unchanged.

    The parameters mirror ``infer`` except for the progress tracker, which is
    left at ``infer``'s default.
    """
    return infer(
        prompt,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
    )

# Prompts offered as one-click examples in the UI (passed to gr.Examples below).
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# Custom CSS handed to gr.Blocks: centres the element with id "col-container"
# (the main column) and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

def format_output(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
    """Run ``infer`` and adapt its dict result to the UI output components.

    ``infer`` returns a dict holding a Base64 PNG data URI and the seed. The UI
    has two outputs (an image and a number), so the dict is unpacked here into
    a tuple. The defaults mirror ``infer`` so that the examples widget, which
    supplies only the prompt, can use this function as well.

    Args:
        prompt: Text prompt.
        seed: Seed forwarded to ``infer``.
        randomize_seed: Whether ``infer`` should pick a random seed.
        width: Requested image width.
        height: Requested image height.
        num_inference_steps: Number of denoising steps.

    Returns:
        ``(image, seed)`` where ``image`` is a PIL image decoded from the data
        URI and ``seed`` is the seed reported by ``infer``.
    """
    output = infer(prompt, seed, randomize_seed, width, height, num_inference_steps)
    # gr.Image is documented to accept a PIL image, NumPy array or file path;
    # a "data:" URI string is not among those, so decode the payload after the
    # comma back into a PIL image before handing it to the component.
    _, _, encoded_png = output["image_base64"].partition(",")
    image = Image.open(io.BytesIO(base64.b64decode(encoded_png)))
    return image, output["seed"]


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
        """)

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)
        seed_output = gr.Number(label="Seed", show_label=True)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
            )

        # The examples feed two output components, so they must go through
        # format_output (which returns a 2-tuple) rather than infer (which
        # returns a single dict).
        gr.Examples(
            examples=examples,
            fn=format_output,
            inputs=[prompt],
            outputs=[result, seed_output],
            cache_examples=True,
            cache_mode="lazy",
        )

    # Invisible components that exist only to carry the JSON API endpoint.
    api_result = gr.JSON(visible=False)
    api_button = gr.Button(visible=False)

    # Interactive UI: run generation on button click or prompt submit.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=format_output,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed_output],
    )

    # Custom API endpoint: api_name belongs on an event listener, not on
    # queue(), so api_infer is registered as "infer_api" through a hidden
    # button event that returns the raw dict (Base64 data URI + seed).
    api_button.click(
        fn=api_infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=api_result,
        api_name="infer_api",
    )

# Launch exactly once, after the Blocks context has been closed.
demo.launch()