# NOTE: removed non-Python scraping artifacts that preceded the code (a
# "File size" header, a gutter of git commit hashes, and a copied
# line-number column). They were not part of the program and made the
# module unparseable.
import gradio as gr
import torch
from diffusers import FluxPipeline
import huggingface_hub
from huggingface_hub import InferenceClient
import os

# Authenticate with the Hugging Face Hub using the HUGGINGFACE_API_TOKEN
# environment variable. NOTE(review): if the variable is unset this passes
# token=None, which triggers an interactive login prompt — verify the token
# is configured in the deployment environment.
huggingface_hub.login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
# Initialize the Flux pipeline
def initialize_flux_pipeline():
    """Build the FLUX.1-dev text-to-image pipeline with the Open Genmoji LoRA.

    Returns:
        FluxPipeline: pipeline configured with CPU offload and LoRA weights
        loaded from ``EvanZhouDev/open-genmoji``.
    """
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    # Page model components between CPU and GPU on demand to reduce VRAM use.
    pipeline.enable_model_cpu_offload()
    # Layer the emoji-style LoRA on top of the base weights.
    pipeline.load_lora_weights(
        "EvanZhouDev/open-genmoji", weight_name="flux-dev.safetensors"
    )
    return pipeline

# Build the pipeline once at import time so every request reuses it.
flux_pipeline = initialize_flux_pipeline()

# Initialize the language model client
# Remote chat client used by refine_prompt() to rewrite user prompts.
llm_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=os.getenv("HUGGINGFACE_API_TOKEN"))

# Function to refine the prompt
def refine_prompt(original_prompt):
    """Rewrite a free-form description into an emoji-friendly image prompt.

    Sends the user's text to the remote chat model together with a fixed
    system instruction and returns the model's reply.

    Args:
        original_prompt: The user's raw description of the desired emoji.

    Returns:
        str: The refined prompt text with surrounding whitespace stripped.

    Raises:
        Exception: propagates any error raised by the inference client.
    """
    system_instructions = (
        "You are helping create a prompt for a Emoji generation image model. An emoji must be easily "
        "interpreted when small so details must be exaggerated to be clear. Your goal is to use descriptions "
        "to achieve this.\n\nYou will receive a user description, and you must rephrase it to consist of "
        "short phrases separated by periods, adding detail to everything the user provides.\n\nAdd describe "
        "the color of all parts or components of the emoji. Unless otherwise specified by the user, do not "
        "describe people. Do not describe the background of the image. Your output should be in the format:\n\n"
        "```emoji of {description}. {addon phrases}. 3D lighting. no cast shadows.```\n\nThe description "
        "should be a 1 sentence of your interpretation of the emoji. Then, you may choose to add addon phrases."
        " You must use the following in the given scenarios:\n\n- \"cute.\": If generating anything that's not "
        "an object, and also not a human\n- \"enlarged head in cartoon style.\": ONLY animals\n- \"head is "
        "turned towards viewer.\": ONLY humans or animals\n- \"detailed texture.\": ONLY objects\n\nFurther "
        "addon phrases may be added to ensure the clarity of the emoji."
    )
    chat = [
        {"role": "system", "content": system_instructions},
        {"role": "user", "content": original_prompt},
    ]
    response = llm_client.chat_completion(chat, max_tokens=100)
    return response["choices"][0]["message"]["content"].strip()

# Define the process function
def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    """Refine the prompt via the LLM, then render an image with the Flux pipeline.

    Args:
        prompt: Raw user description of the emoji.
        guidance_scale: Classifier-free guidance strength for the pipeline.
        num_inference_steps: Number of denoising steps.
        height: Output image height in pixels.
        width: Output image width in pixels.
        seed: Seed for reproducible generation (coerced to int).

    Returns:
        A PIL image on success, or a human-readable error string on failure —
        the Gradio output component receives whichever comes back.

    NOTE(review): the click handler also calls refine_prompt() for display, so
    each request currently performs two LLM calls; consider returning the
    refined prompt from here instead.
    """
    print(f"Original Prompt: {prompt}")

    # Refine the prompt; report failures as text so the UI shows a message
    # instead of an unhandled exception.
    try:
        refined_prompt = refine_prompt(prompt)
        print(f"Refined Prompt: {refined_prompt}")
    except Exception as e:
        return f"Error refining prompt: {str(e)}"

    # Fix: the device was hard-coded to "cuda", which crashes on CPU-only
    # hosts even though the pipeline itself relies on CPU offload. Pick the
    # available device instead. gr.Number may deliver the seed as a float,
    # so coerce to int before seeding.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(int(seed))

    try:
        # Generate the image
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        )
        return output.images[0]
    except Exception as e:
        return f"Error generating image: {str(e)}"

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")

    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=64)
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=64)

    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")

    # Button to generate the image
    generate_button = gr.Button("Generate Image")

    def _on_generate(prompt, *gen_args):
        """Click handler: return (refined prompt text, generated image).

        NOTE(review): refine_prompt() runs here for display AND again inside
        process(), so each click costs two LLM calls and the displayed text
        may differ from the prompt actually used for the image.
        """
        # Fix: the original lambda let a refine_prompt() failure raise out of
        # the handler; surface it in the text box instead, matching the
        # error-string convention used by process().
        try:
            refined = refine_prompt(prompt)
        except Exception as e:
            return f"Error refining prompt: {str(e)}", None
        return refined, process(prompt, *gen_args)

    # Define button click behavior
    generate_button.click(
        fn=_on_generate,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )

# Launch the app
if __name__ == "__main__":
    # show_error surfaces handler exceptions in the browser UI rather than
    # failing silently.
    demo.launch(show_error=True)