File size: 3,516 Bytes
0dec378
 
 
 
 
 
 
 
a5a56d7
9655256
0dec378
 
 
9655256
 
a5a56d7
0dec378
 
9655256
0dec378
 
 
 
 
 
 
 
 
 
 
 
 
dbdd92a
 
 
 
9655256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01681c8
 
0dec378
 
 
 
 
 
 
9655256
01681c8
9655256
 
01681c8
0dec378
01681c8
0dec378
9655256
0dec378
7b4c302
 
0dec378
 
9655256
 
 
0dec378
9655256
 
 
 
c597ed5
9655256
 
 
 
 
01681c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import gradio as gr
import torch
import numpy as np
import random
from diffusers import FluxPipeline, FluxTransformer2DModel
import spaces
from translatepy import Translator

# Environment setup.
# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER when
# it is first imported, and diffusers is already imported above, so this
# assignment may come too late to take effect — confirm, and if so set the
# variable before the imports.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
# Value of the HF_TOKEN environment variable, or None when it is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Shared translator instance; generate_image uses it to translate prompts
# into English.
translator = Translator()

# Constants.
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = int(np.iinfo(np.int32).max)
# Hub repository id passed to FluxPipeline.from_pretrained.
model = "black-forest-labs/FLUX.1-dev"

# CSS and JS settings.
# CSS: hide the page footer.
CSS = """
footer {
    visibility: hidden;
}
"""

# JS: reload the page once with the `__theme=dark` query parameter set.
# The URL API is used instead of string concatenation so that an existing
# query string or fragment is preserved (appending "?__theme=dark" to a URL
# that already has "?..." or "#..." would produce a malformed parameter).
JS = """function () {
  const url = new URL(window.location.href);
  if (url.searchParams.get('__theme') !== 'dark') {
    url.searchParams.set('__theme', 'dark');
    window.location.replace(url.href);
  }
}"""

# Global pipeline handle. It stays None when no GPU is present or loading
# fails; generate_image checks for None before running.
pipe = None

# Try to load the model.
try:
    # Check for CUDA first so the large transformer checkpoint is not
    # downloaded and loaded on a machine that cannot run the pipeline anyway.
    if torch.cuda.is_available():
        # Use the merged transformer weights in place of the transformer
        # that ships with the base pipeline.
        transformer = FluxTransformer2DModel.from_pretrained(
            "sayakpaul/FLUX.1-merged",
            torch_dtype=torch.bfloat16,
        )
        pipe = FluxPipeline.from_pretrained(
            model,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
    else:
        print("CUDA is not available. Check your GPU settings.")
except Exception as e:
    # Best effort: keep the app importable and report the failure; `pipe`
    # remains None so the UI can show an error instead of crashing.
    print(f"Failed to load the model: {e}")

# Image generation function.
# NOTE(review): `spaces` is imported at the top of the file but never used; if
# this app is deployed on ZeroGPU hardware this function probably needs the
# `@spaces.GPU` decorator — confirm against the deployment.
def generate_image(prompt, width=1024, height=1024, scales=5, steps=4, seed=-1, nums=1, progress=gr.Progress(track_tqdm=True)):
    """Generate images for *prompt* with the globally loaded pipeline.

    The prompt is translated to English before being passed to the pipeline;
    if translation fails, the prompt is used as entered.

    Args:
        prompt: Text prompt in any language.
        width: Output width, passed to the pipeline as an int.
        height: Output height, passed to the pipeline as an int.
        scales: Guidance scale passed to the pipeline.
        steps: Number of inference steps, passed as an int.
        seed: Seed for the random generator; -1 (or None) picks a random
            seed in [0, MAX_SEED].
        nums: Number of images to generate per prompt, passed as an int.
        progress: Not used directly; declared with track_tqdm=True so Gradio
            can report tqdm progress.

    Returns:
        A tuple ``(images, seed)``: the generated images and the seed that
        was actually used.

    Raises:
        gr.Error: If the model is not loaded, the prompt is empty, or the
            pipeline fails. Raising (rather than returning a message) keeps
            a string from being written into the numeric seed output.
    """
    if pipe is None:
        print("Model is not loaded properly. Please check the logs for details.")
        raise gr.Error("Model not loaded.")

    if prompt is None or not str(prompt).strip():
        raise gr.Error("Please enter a prompt.")

    # Normalize the seed first so -1.0 and None are also treated as "random".
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        seed = random.randint(0, MAX_SEED)

    # Translation is best effort: a translator failure should not prevent
    # generation for prompts that are already usable.
    try:
        text = str(translator.translate(prompt, 'English'))
    except Exception as e:
        print(f"Translation failed, using the prompt as entered: {e}")
        text = str(prompt)

    generator = torch.Generator().manual_seed(seed)

    try:
        images = pipe(
            prompt=text,
            # UI sliders may deliver floats; these arguments must be integers.
            height=int(height),
            width=int(width),
            guidance_scale=scales,
            num_inference_steps=int(steps),
            max_sequence_length=512,
            num_images_per_prompt=int(nums),
            generator=generator,
        ).images
    except Exception as e:
        print(f"Error generating image: {e}")
        raise gr.Error("Error during image generation.") from e

    return images, seed

# Build and launch the Gradio interface.
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>Flux Labs</center></h1>")
    # The emoji in this string was previously stored as mis-decoded bytes and
    # rendered as garbage characters in the page.
    gr.HTML("<p><center>Model Now: <a href='https://huggingface.co/sayakpaul/FLUX.1-merged'>FLUX.1 Merged</a><br>🙇‍♂️Frequent model changes</center></p>")
    with gr.Row():
        with gr.Column(scale=4):
            img = gr.Gallery(label='flux Generated Image', columns=1, preview=True, height=600)
            prompt = gr.Textbox(label='Enter Your Prompt (Multi-Languages)', placeholder="Enter prompt...", scale=6)
            sendBtn = gr.Button(scale=1, variant='primary')
        with gr.Accordion("Advanced Options", open=True):
            width = gr.Slider(label="Width", minimum=512, maximum=1280, step=8, value=1024)
            height = gr.Slider(label="Height", minimum=512, maximum=1280, step=8, value=1024)
            scales = gr.Slider(label="Guidance", minimum=3.5, maximum=7, step=0.1, value=3.5)
            steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=4)
            # A seed of -1 makes generate_image pick a random seed.
            seed = gr.Slider(label="Seeds", minimum=-1, maximum=MAX_SEED, step=1, value=0)
            nums = gr.Slider(label="Image Numbers", minimum=1, maximum=4, step=1, value=1)
    # Clicking the button and pressing Enter in the prompt box run the same
    # handler; the seed it returns is written back to the seed slider.
    for trigger in (sendBtn.click, prompt.submit):
        trigger(
            fn=generate_image,
            inputs=[prompt, width, height, scales, steps, seed, nums],
            outputs=[img, seed],
        )

demo.queue().launch()