hidream-arena / app.py
chengzeyi's picture
Upload folder using huggingface_hub
ccb88d2 verified
raw
history blame
20.8 kB
import gradio as gr
import asyncio
import aiohttp
import time
from datetime import datetime
import plotly.graph_objects as go
from typing import Dict, List
import os
from dotenv import load_dotenv
import json
from PIL import Image, ImageDraw, ImageFont
import uuid
import threading
# Load environment variables (via python-dotenv) before anything below
# reads them.
load_dotenv()

# Root of the WaveSpeed v2 REST API; every endpoint below is built from it.
API_BASE_URL = "https://api.wavespeed.ai/api/v2"

# Bearer token for the WaveSpeed API. Fail at import time so the app never
# starts without credentials.
API_KEY = os.environ.get("WAVESPEED_API_KEY")
if not API_KEY:
    raise ValueError("WAVESPEED_API_KEY not found in environment variables")

# Models compared by the arena, keyed by internal backend id. Each entry holds
# the submission endpoint, a display name, and the colour used in the chart.
BACKENDS = {
    backend_id: {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/{model_path}",
        "name": display_name,
        "color": bar_color,
    }
    for backend_id, model_path, display_name, bar_color in (
        ("flux-dev", "flux-dev-ultra-fast", "Flux-dev", "#FF9800"),
        ("hidream-dev", "hidream-i1-dev", "HiDream-dev", "#2196F3"),
        ("hidream-full", "hidream-i1-full", "HiDream-full", "#4CAF50"),
    )
}
class BackendStatus:
    """Track the lifecycle and timing history of one generation backend.

    ``status`` is one of "idle", "processing", "completed" or "failed".
    Every successfully timed run appends ``{"timestamp", "duration"}`` to
    ``history``, where ``duration`` is elapsed seconds.
    """

    def __init__(self):
        self.reset()
        # Timed runs; reset() leaves this list untouched.
        self.history: List[Dict] = []

    def reset(self):
        """Return to the idle state without clearing ``history``."""
        self.status = "idle"
        self.progress = 0
        # time.monotonic() readings: only meaningful as differences, which
        # keeps durations immune to wall-clock adjustments.
        self.start_time = None
        self.end_time = None

    def start(self):
        """Mark a new run as in progress and start its timer."""
        self.status = "processing"
        self.progress = 0
        self.start_time = time.monotonic()
        self.end_time = None

    def complete(self):
        """Mark the run as completed and record its duration exactly once.

        A history entry is appended only for a run that was started and has
        not already finished (completed or failed). A repeated complete(), or
        one without a preceding start(), updates the status but never adds a
        bogus entry or raises.
        """
        self.status = "completed"
        self.progress = 100
        if self.start_time is None or self.end_time is not None:
            return
        self.end_time = time.monotonic()
        self.history.append({
            "timestamp": datetime.now(),  # wall-clock, for display only
            "duration": self.end_time - self.start_time,
        })

    def fail(self):
        """Mark the run as failed; failed runs are not added to history."""
        self.status = "failed"
        self.end_time = time.monotonic()
class SessionManager:
    """Process-wide registry mapping session ids to GenerationManager objects.

    The registry is a class-level dict; every read and write of it happens
    while holding the class-level lock.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, session_id=None):
        """Return ``(session_id, manager)``, creating whatever is missing.

        A new UUID4 string is generated when *session_id* is None, and a new
        GenerationManager is registered the first time an id is seen.
        """
        sid = str(uuid.uuid4()) if session_id is None else session_id
        with cls._lock:
            try:
                manager = cls._instances[sid]
            except KeyError:
                manager = cls._instances[sid] = GenerationManager()
        return sid, manager

    @classmethod
    def cleanup_old_sessions(cls, max_age=3600):
        """Drop sessions idle for more than *max_age* seconds (default 1 h).

        Entries without a ``last_activity`` attribute are never evicted.
        """
        now = time.time()
        with cls._lock:
            stale_ids = [
                sid
                for sid, manager in cls._instances.items()
                if hasattr(manager, "last_activity")
                and now - manager.last_activity > max_age
            ]
            for sid in stale_ids:
                del cls._instances[sid]
class GenerationManager:
    """Per-session state: one BackendStatus per backend plus an activity stamp."""

    def __init__(self):
        self.backend_statuses = {
            backend: BackendStatus()
            for backend in BACKENDS
        }
        # Wall-clock time (time.time()) of the last interaction; compared
        # against time.time() by SessionManager.cleanup_old_sessions.
        self.last_activity = time.time()

    def update_activity(self):
        """Record that this session was just used."""
        self.last_activity = time.time()

    def get_performance_plot(self):
        """Build a bar chart of the mean generation time per backend.

        Only backends with at least one recorded run get a bar. With no data
        at all, the figure carries a "No timing data available yet" note.

        Returns:
            A plotly ``go.Figure``.
        """
        fig = go.Figure()
        has_data = False
        max_duration = 0.0
        for backend, status in self.backend_statuses.items():
            durations = [h["duration"] for h in status.history]
            if not durations:
                continue
            has_data = True
            max_duration = max(max_duration, max(durations))
            avg_duration = sum(durations) / len(durations)
            info = BACKENDS[backend]
            fig.add_trace(
                go.Bar(
                    x=[info["name"]],
                    y=[avg_duration],
                    name=info["name"],
                    marker_color=info["color"],
                    text=[f"{avg_duration:.2f}s"],
                    textposition="auto",
                    width=[0.5],  # Narrower bars.
                ))
        # Pin the y-axis at 0 with 20% headroom above the slowest single run.
        # Skipped when no run took measurable time, so the range is never
        # zero-height.
        if max_duration > 0:
            fig.update_yaxes(range=[0, max_duration * 1.2])
        fig.update_layout(
            title="Average Generation Time",
            yaxis_title="Seconds",
            xaxis_title="",
            showlegend=False,
            template="simple_white",
            height=400,
            margin=dict(l=50, r=50, t=50, b=50),
            font=dict(size=14),
        )
        if not has_data:
            # Keep the plot area from rendering as a blank box.
            fig.add_annotation(
                text="No timing data available yet",
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
                showarrow=False,
                font=dict(size=16),
            )
        return fig

    async def submit_task(self, backend: str, prompt: str) -> str:
        """Submit *prompt* to *backend* and return the API request id.

        The backend's timer is started before the request. On any failure
        the backend is marked failed and the error is re-raised.

        Raises:
            Exception: "Failed to submit task: ..." wrapping the underlying
                error (non-200 response, network error, or a response body
                without ``data.id``); the original is kept as ``__cause__``.
        """
        status = self.backend_statuses[backend]
        status.start()
        try:
            url = BACKENDS[backend]["endpoint"]
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {API_KEY}",
            }
            payload = {
                "prompt": prompt,
                "enable_safety_checker": False,
                "enable_base64_output": True,  # Request inline base64 images.
                "size": "1024*1024",
                "seed": -1,
            }
            if backend == "flux-dev":
                # Extra sampling parameters sent only to the Flux endpoint.
                payload.update({
                    "guidance_scale": 3.5,
                    "num_images": 1,
                    "num_inference_steps": 28,
                    "strength": 0.8,
                })
            print(f"Submitting task to {backend}")
            print(f"URL: {url}")
            print(f"Payload: {json.dumps(payload, indent=2)}")
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers,
                                        json=payload) as response:
                    if response.status != 200:
                        text = await response.text()
                        raise Exception(
                            f"API error: {response.status}, {text}")
                    result = await response.json()
            request_id = result["data"]["id"]
            print(f"Task submitted successfully. Request ID: {request_id}")
            return request_id
        except Exception as e:
            status.fail()
            raise Exception(f"Failed to submit task: {str(e)}") from e

    def reset_history(self):
        """Clear the recorded timings of every backend and return self."""
        for status in self.backend_statuses.values():
            status.history = []
        return self
def create_error_image(backend, error_message):
    """Render *error_message* onto a 512x512 placeholder PNG.

    Args:
        backend: Backend id the error belongs to. It is currently not drawn;
            the parameter is kept for the callers' signature.
        error_message: Text to display. It is wrapped at about 40 characters,
            and lines that would run off the image are elided with " ...".

    Returns:
        A ``data:image/png;base64,...`` URI, or ``"Error: <message>"`` when the
        image could not be produced.
    """
    try:
        import base64
        import textwrap
        from io import BytesIO

        img = Image.new("RGB", (512, 512), color="#ffdddd")
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("Arial", 20)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable: fall back to
            # PIL's built-in bitmap font.
            font = ImageFont.load_default()

        top, line_height = 100, 30
        # Number of 30px lines that fit between y=100 and the bottom edge.
        max_lines = (512 - top) // line_height
        lines = textwrap.wrap(error_message, width=40,
                              max_lines=max_lines, placeholder=" ...")
        for index, line in enumerate(lines):
            draw.text((50, top + index * line_height), line,
                      fill="black", font=font)

        # Encode in memory; nothing is written to disk.
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"
    except Exception as e:
        # Best effort: a plain-text message is still shown to the user.
        print(f"Failed to create error image: {e}")
        return "Error: " + error_message
async def poll_once(manager, backend, request_id):
    """Poll the WaveSpeed result endpoint once for *request_id*.

    Args:
        manager: GenerationManager whose ``backend_statuses[backend]`` is
            updated when the task reaches a terminal state.
        backend: Backend id (a key of BACKENDS).
        request_id: Id returned by GenerationManager.submit_task.

    Returns:
        The image as a URL or ``data:`` URI once the task has completed, or
        None while it is still in progress.

    Raises:
        Exception: on a non-200 response; or when the task is reported as
            "failed" or completes without outputs. In those two cases the
            backend status is marked failed before raising.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    url = f"{API_BASE_URL}/predictions/{request_id}/result"
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            if response.status != 200:
                raise Exception(f"Poll error: {response.status}")
            result = await response.json()

    data = result["data"]
    current_status = data["status"]
    backend_status = manager.backend_statuses[backend]

    if current_status == "completed":
        outputs = data.get("outputs") or []
        if not outputs:
            # Fail rather than recording a completion (and a duration) for a
            # run that produced nothing.
            backend_status.fail()
            manager.update_activity()
            raise Exception("Task completed without outputs")
        output = outputs[0]
        # Record completion (and its timing) before handing the image back.
        backend_status.complete()
        manager.update_activity()
        if isinstance(output, str) and not output.startswith(
                ("http", "data:image")):
            # Raw base64 payload: wrap it as a data URI so Gradio can
            # display it directly.
            return f"data:image/png;base64,{output}"
        # URL, existing data URI, or non-string output: return unchanged.
        return output

    if current_status == "failed":
        backend_status.fail()
        manager.update_activity()
        raise Exception(data.get("error") or "Unknown error")

    # Any other status means the task is still being processed.
    return None
# Gradio UI: layout, styling, the generation handler and its event wiring.
# NOTE(review): nesting in this section was reconstructed from an
# indentation-stripped copy; confirm component nesting against the deployed
# layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session id kept in gr.State; None until the first generation.
    session_id = gr.State(None)
    gr.Markdown("# 🌊 WaveSpeed AI Image Generator")
    gr.Markdown(
        "[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation."
    )
    gr.Markdown(
        "Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries."
    )
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type here...",
                lines=3,
            )
        with gr.Column(scale=1):
            generate_btn = gr.Button("Generate", variant="primary")
    # Two status boxes: the small one is shown while idle, the big one while
    # a generation runs. The handler toggles their visibility.
    small_status_box = gr.Markdown("Ready to generate images",
                                   elem_id="small-status")
    with gr.Row(elem_id="big-status-row"):
        big_status_box = gr.Markdown("",
                                     elem_id="big-status",
                                     visible=False,
                                     elem_classes="big-status-box")
    with gr.Row():
        with gr.Column():
            draft_output = gr.Image(label="Flux-dev")
        with gr.Column():
            quick_output = gr.Image(label="HiDream-dev")
        with gr.Column():
            best_output = gr.Image(label="HiDream-full")
    performance_plot = gr.Plot(label="Performance Metrics")

    # Styling for the big status box, including its "breathing" animation.
    css = """
#big-status-row {
margin: 20px 0;
}
#big-status {
font-size: 28px; /* Even larger font size */
font-weight: bold;
padding: 30px; /* More padding */
background-color: #0D47A1; /* Deeper blue background */
color: white; /* White text */
border-radius: 10px;
text-align: center;
margin: 0 auto;
box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2); /* Stronger shadow */
animation: deep-breath 3s infinite; /* Slower, deeper breathing animation */
width: 100%; /* Full width */
max-width: 800px; /* Maximum width */
transition: all 0.3s ease; /* Smooth transitions */
border-left: 6px solid #64B5F6; /* Add a colored border */
border-right: 6px solid #64B5F6; /* Add a colored border */
}
/* Deeper breathing animation */
@keyframes deep-breath {
0% {
opacity: 0.7;
transform: scale(0.98);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
}
50% {
opacity: 1;
transform: scale(1.01);
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.3);
}
100% {
opacity: 0.7;
transform: scale(0.98);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
}
}
"""
    gr.HTML(f"<style>{css}</style>")

    async def generate_all_backends_with_status_boxes(prompt,
                                                      current_session_id):
        """Generate one image per backend, streaming UI updates as they land.

        Async generator bound to ``generate_btn.click``. Each yielded value is
        a 9-tuple in the order of that handler's ``outputs``: small status
        text, big status text, big-box visibility, small-box visibility, the
        Flux-dev / HiDream-dev / HiDream-full images, the performance plot,
        and the session id.

        Args:
            prompt: Text from the prompt box.
            current_session_id: Value held in the ``session_id`` gr.State;
                None on a session's first request.
        """
        active_session_id, manager = SessionManager.get_manager(
            current_session_id)
        manager.update_activity()

        results = {backend: None for backend in BACKENDS}

        def current_images():
            # Order matches draft_output, quick_output, best_output.
            return (results["flux-dev"], results["hidream-dev"],
                    results["hidream-full"])

        def ui_state(message, big_visible, images=(None, None, None),
                     plot=None):
            # Pack one output tuple; exactly one status box is visible.
            return (
                message,
                message,
                gr.update(visible=big_visible),      # big_status_box
                gr.update(visible=not big_visible),  # small_status_box
                *images,
                plot,
                active_session_id,
            )

        if not prompt or prompt.strip() == "":
            yield ui_state("⚠️ Please enter a prompt first", True)
            return

        # A new prompt starts a fresh comparison: drop earlier timings.
        manager.reset_history()

        status_message = f"🔄 PROCESSING: '{prompt}'"
        # Clear all images and show the big status box.
        yield ui_state(status_message, True)

        completed_backends = set()
        failed_backends = set()
        try:
            # Submit to every backend concurrently; a failed submission comes
            # back as an exception object instead of aborting the others.
            backends = list(BACKENDS)
            outcomes = await asyncio.gather(
                *(manager.submit_task(backend, prompt) for backend in backends),
                return_exceptions=True,
            )
            request_ids = {}
            for backend, outcome in zip(backends, outcomes):
                if isinstance(outcome, BaseException):
                    print(f"Error submitting task for {backend}: {outcome}")
                    results[backend] = create_error_image(backend,
                                                          str(outcome))
                    completed_backends.add(backend)
                    failed_backends.add(backend)
                else:
                    request_ids[backend] = outcome

            # Bounds polling to at least 30 s of sleeping plus request time.
            max_poll_attempts = 300
            poll_attempt = 0
            while (len(completed_backends) < len(BACKENDS)
                   and poll_attempt < max_poll_attempts):
                poll_attempt += 1
                for backend in BACKENDS:
                    if backend in completed_backends:
                        continue
                    # flux-dev is polled on every pass, the other backends
                    # only on every second pass to reduce API load.
                    if backend != "flux-dev" and poll_attempt % 2 != 0:
                        continue
                    try:
                        result = await poll_once(manager, backend,
                                                 request_ids[backend])
                    except Exception as e:
                        print(f"Error polling {backend}: {str(e)}")
                        # poll_once marks the backend "failed" before raising
                        # for a task the API reports as failed. Such a task
                        # will never finish, so show the error and stop
                        # polling it; other errors are retried next pass.
                        if manager.backend_statuses[backend].status != "failed":
                            continue
                        result = create_error_image(backend, str(e))
                        failed_backends.add(backend)
                    if not result:
                        continue  # Still processing.
                    results[backend] = result
                    completed_backends.add(backend)
                    yield ui_state(status_message, True, current_images(),
                                   manager.get_performance_plot())
                await asyncio.sleep(0.1)

            if len(completed_backends) < len(BACKENDS):
                final_status = "⚠️ Some generations timed out"
            elif failed_backends:
                final_status = "⚠️ Some generations failed"
            else:
                final_status = "✅ All generations completed!"
            yield ui_state(final_status, False, current_images(),
                           manager.get_performance_plot())
        except Exception as e:
            error_message = f"❌ Error: {str(e)}"
            yield ui_state(error_message, False)

    def cleanup_task():
        """Evict idle sessions now, then re-arm itself to run in an hour."""
        SessionManager.cleanup_old_sessions()
        timer = threading.Timer(3600, cleanup_task)
        # Daemon thread so a pending timer never blocks interpreter shutdown.
        timer.daemon = True
        timer.start()

    cleanup_task()

    generate_btn.click(
        fn=generate_all_backends_with_status_boxes,
        inputs=[input_text, session_id],
        # Each status box appears twice: once for its text and once for its
        # visibility update, matching the handler's 9-tuple.
        outputs=[
            small_status_box,
            big_status_box,
            big_status_box,  # visibility
            small_status_box,  # visibility
            draft_output,
            quick_output,
            best_output,
            performance_plot,
            session_id,
        ],
        api_name="generate",
        # NOTE(review): Gradio documents max_batch_size as relevant only when
        # batch=True, which is not set here — confirm whether it is needed.
        max_batch_size=10,
        concurrency_limit=20,  # Up to 20 concurrent runs of this handler.
        concurrency_id="generation",
    )
# Script entry point: serve the app with a bounded request queue.
if __name__ == "__main__":
    app = demo.queue(max_size=50)  # At most 50 pending requests.
    app.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces.
        max_threads=16,
    )