import os

import gradio as gr
import torch
from diffusers import FluxPipeline
from huggingface_hub import InferenceClient
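# Requires: gradio, torch, diffusers, huggingface_hub (plus the `accelerate`
# package for enable_model_cpu_offload), e.g.:
#   pip install gradio torch diffusers huggingface_hub accelerate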

# Initialize the Flux pipeline once at startup so every request reuses it
def initialize_flux_pipeline():
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    # Offload idle submodules to CPU to cut peak GPU memory (needs accelerate)
    pipe.enable_model_cpu_offload()
    # Apply the Open Genmoji LoRA on top of the base FLUX weights
    pipe.load_lora_weights("EvanZhouDev/open-genmoji")
    return pipe

flux_pipeline = initialize_flux_pipeline()

# Initialize the language model client (requires the HUGGINGFACE_API_TOKEN
# environment variable to be set to a valid Hugging Face token)
llm_client = InferenceClient(
    "Qwen/Qwen2.5-72B-Instruct",
    token=os.getenv("HUGGINGFACE_API_TOKEN"),
)

# Function to refine the prompt
def refine_prompt(original_prompt):
    messages = [
        {
            "role": "system",
            "content": (
                "You are refining a user-provided description for generating images. The output should focus on clarity, "
                "detail, and vivid descriptions. The format should be concise and effective for image generation."
            ),
        },
        {"role": "user", "content": original_prompt},
    ]
    completion = llm_client.chat_completion(messages, max_tokens=100)
    # chat_completion returns a dataclass; use attribute access on the result
    refined = completion.choices[0].message.content.strip()
    return refined
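# Note: chat completions are generally not deterministic, so two calls to
# refine_prompt() with the same input may return different text. process()
# below therefore refines once and reuses that exact string for generation.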

# Generate an image from a prompt. The prompt is refined exactly once, and
# that same refined text is used for generation and returned for display.
def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    print(f"Original Prompt: {prompt}")

    # Refine the prompt; fall back to the original text if the LLM call fails
    try:
        refined_prompt = refine_prompt(prompt)
        print(f"Refined Prompt: {refined_prompt}")
    except Exception as e:
        print(f"Prompt refinement failed, using original prompt: {e}")
        refined_prompt = prompt

    # Seed the generator for reproducibility; fall back to CPU without CUDA
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(int(seed))

    try:
        # Generate the image
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        )
        image = output.images[0]
        return refined_prompt, image
    except Exception as e:
        # Surface the failure in the UI rather than handing the Image
        # component an error string
        raise gr.Error(f"Error generating image: {e}")

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")

    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=64)
        # Step of 16 (rather than 64) so the default width of 1360 lies on the
        # slider grid; FLUX only requires dimensions divisible by 16
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=16)

    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")

    # Button to generate the image
    generate_button = gr.Button("Generate Image")

    # Wire the button to process(), which refines the prompt exactly once and
    # returns both the refined prompt and the generated image
    generate_button.click(
        fn=process,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(show_error=True)
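
# To run locally (assumes a CUDA-capable GPU with enough memory for
# FLUX.1-dev and a Hugging Face token with access to the model):
#
#   export HUGGINGFACE_API_TOKEN=hf_...   # your token
#   python app.py                         # or whatever this file is named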