import gradio as gr
import torch
from diffusers import FluxPipeline
from huggingface_hub import InferenceClient
import os
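# Note (inferred from the code below, not stated elsewhere in this file): the InferenceClient
# reads an API token from the HUGGINGFACE_API_TOKEN environment variable, and image generation
# creates a torch.Generator on "cuda", so the Space is assumed to run on a CUDA-capable GPU.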
# Initialize the Flux pipeline
def initialize_flux_pipeline():
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    # Offload submodules to CPU when idle to reduce GPU memory usage
    pipe.enable_model_cpu_offload()
    # Apply the Open Genmoji LoRA weights on top of the base model
    pipe.load_lora_weights("EvanZhouDev/open-genmoji")
    return pipe
flux_pipeline = initialize_flux_pipeline()
# Initialize the language model client
llm_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=os.getenv("HUGGINGFACE_API_TOKEN"))
# Function to refine the prompt
def refine_prompt(original_prompt):
    messages = [
        {
            "role": "system",
            "content": (
                "You are refining a user-provided description for generating images. The output should focus on clarity, "
                "detail, and vivid descriptions. The format should be concise and effective for image generation."
            ),
        },
        {"role": "user", "content": original_prompt},
    ]
    completion = llm_client.chat_completion(messages, max_tokens=100)
    refined = completion["choices"][0]["message"]["content"].strip()
    return refined
# Define the process function: refine the prompt, then generate the image.
# It returns (refined_prompt, image) so the UI shows the exact prompt used for generation.
def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    print(f"Original Prompt: {prompt}")
    # Refine the prompt
    try:
        refined_prompt = refine_prompt(prompt)
        print(f"Refined Prompt: {refined_prompt}")
    except Exception as e:
        raise gr.Error(f"Error refining prompt: {e}")
    # Set the random generator seed for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    try:
        # Generate the image
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            height=int(height),
            width=int(width),
            generator=generator,
        )
        image = output.images[0]
        return refined_prompt, image
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}")
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")
    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=64)
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=64)
    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")
    # Button to generate the image
    generate_button = gr.Button("Generate Image")
    # Define button click behavior: process() refines the prompt once and
    # returns (refined_prompt, image), matching the two outputs below.
    generate_button.click(
        fn=process,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(show_error=True)
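# Dependency note (approximate, inferred from the imports above rather than pinned by this file):
# gradio, torch, diffusers, and huggingface_hub are imported directly; diffusers' FluxPipeline
# typically also needs transformers, plus accelerate for enable_model_cpu_offload() and
# peft for load_lora_weights().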