import gradio as gr
import asyncio
import aiohttp
import time
from datetime import datetime
import plotly.graph_objects as go
from typing import Dict, List
import os
from dotenv import load_dotenv
import json
from PIL import Image, ImageDraw, ImageFont
import uuid
import threading

# Load environment variables first
load_dotenv()

# Constants
API_BASE_URL = "https://api.wavespeed.ai/api/v2"
API_KEY = os.getenv("WAVESPEED_API_KEY")  # Keep API_KEY in global scope
if not API_KEY:
    raise ValueError("WAVESPEED_API_KEY not found in environment variables")
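
# Expected .env layout (sketch; only the variable name above is required):
# WAVESPEED_API_KEY=<your API key>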

# Rest of constants
BACKENDS = {
    "flux-dev": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/flux-dev-ultra-fast",
        "name": "Flux-dev",
        "color": "#FF9800",
    },
    "hidream-dev": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-dev",
        "name": "HiDream-dev",
        "color": "#2196F3",
    },
    "hidream-full": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-full",
        "name": "HiDream-full",
        "color": "#4CAF50",
    },
}
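
# Another model can be added by extending BACKENDS (hypothetical entry; the
# endpoint path, display name, and color are placeholders, not a real model):
# BACKENDS["my-model"] = {
#     "endpoint": f"{API_BASE_URL}/wavespeed-ai/my-model",
#     "name": "My-model",
#     "color": "#9C27B0",
# }
# Note that the UI and result handling below hard-code three backends, so a
# new entry also needs a matching gr.Image component and result slot.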


class BackendStatus:

    def __init__(self):
        self.reset()
        self.history: List[Dict] = []

    def reset(self):
        self.status = "idle"
        self.progress = 0
        self.start_time = None
        self.end_time = None

    def start(self):
        self.status = "processing"
        self.progress = 0
        self.start_time = time.time()
        self.end_time = None

    def complete(self):
        self.status = "completed"
        self.progress = 100
        self.end_time = time.time()
        self.history.append({
            "timestamp": datetime.now(),
            "duration": self.end_time - self.start_time,
        })

    def fail(self):
        self.status = "failed"
        self.end_time = time.time()
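
# Lifecycle sketch for a single BackendStatus (comments only, not executed):
#   status = BackendStatus()
#   status.start()       # "processing", records start_time
#   ...                  # task runs / gets polled
#   status.complete()    # "completed", appends {"timestamp", "duration"} to history
#   # or status.fail()   # "failed"; no history entry is recorded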


class SessionManager:
    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, session_id=None):
        if session_id is None:
            session_id = str(uuid.uuid4())
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = GenerationManager()
            return session_id, cls._instances[session_id]

    @classmethod
    def cleanup_old_sessions(cls, max_age=3600):  # 1 hour default
        current_time = time.time()
        with cls._lock:
            to_remove = []
            for session_id, manager in cls._instances.items():
                if (hasattr(manager, "last_activity")
                        and current_time - manager.last_activity > max_age):
                    to_remove.append(session_id)
            for session_id in to_remove:
                del cls._instances[session_id]
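
# Usage sketch (comments only): the Gradio handler below passes the per-user
# session id stored in gr.State and gets back the same id plus its manager:
#   session_id, manager = SessionManager.get_manager(current_session_id)
#   SessionManager.cleanup_old_sessions(max_age=3600)  # drop idle sessions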


class GenerationManager:

    def __init__(self):
        self.backend_statuses = {
            backend: BackendStatus()
            for backend in BACKENDS
        }
        self.last_activity = time.time()

    def update_activity(self):
        self.last_activity = time.time()

    def get_performance_plot(self):
        fig = go.Figure()
        has_data = False
        for backend, status in self.backend_statuses.items():
            durations = [h["duration"] for h in status.history]
            if durations:
                has_data = True
                avg_duration = sum(durations) / len(durations)
                # Use bar chart instead of box plot
                fig.add_trace(
                    go.Bar(
                        y=[avg_duration],  # Average duration
                        x=[BACKENDS[backend]["name"]],  # Backend name
                        name=BACKENDS[backend]["name"],
                        marker_color=BACKENDS[backend]["color"],
                        text=[f"{avg_duration:.2f}s"],  # Show time in seconds
                        textposition="auto",
                        width=[0.5],  # Make bars narrower
                    ))

        # Set a minimum y-axis range if we have data
        if has_data:
            max_duration = max([
                max([h["duration"] for h in status.history] or [0])
                for status in self.backend_statuses.values()
            ])
            # Add 20% padding to the top
            y_max = max_duration * 1.2
            # Ensure the y-axis always starts at 0
            fig.update_yaxes(range=[0, y_max])

        fig.update_layout(
            title="Average Generation Time",
            yaxis_title="Seconds",
            xaxis_title="",
            showlegend=False,
            template="simple_white",
            height=400,  # Increase height
            margin=dict(l=50, r=50, t=50, b=50),  # Add margins
            font=dict(size=14),  # Larger font
        )

        # Make sure we have a valid figure even if no data
        if not has_data:
            fig.add_annotation(
                text="No timing data available yet",
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
                showarrow=False,
                font=dict(size=16),
            )
        return fig
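
    # The figure can also be inspected outside Gradio (sketch):
    #   fig = GenerationManager().get_performance_plot()
    #   fig.show()  # shows the "No timing data available yet" placeholder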

    async def submit_task(self, backend: str, prompt: str) -> str:
        status = self.backend_statuses[backend]
        status.start()

        try:
            url = BACKENDS[backend]["endpoint"]
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {API_KEY}",
            }
            payload = {
                "prompt": prompt,
                "enable_safety_checker": False,
                "enable_base64_output": True,  # Enable base64 output
                "size": "1024*1024",
                "seed": -1,
            }
            if backend == "flux-dev":
                payload.update({
                    "guidance_scale": 3.5,
                    "num_images": 1,
                    "num_inference_steps": 28,
                    "strength": 0.8,
                })

            print(f"Submitting task to {backend}")
            print(f"URL: {url}")
            print(f"Payload: {json.dumps(payload, indent=2)}")

            # Use aiohttp instead of requests for async
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers,
                                        json=payload) as response:
                    if response.status == 200:
                        result = await response.json()
                        request_id = result["data"]["id"]
                        print(
                            f"Task submitted successfully. Request ID: {request_id}"
                        )
                        return request_id
                    else:
                        text = await response.text()
                        raise Exception(
                            f"API error: {response.status}, {text}")
        except Exception as e:
            status.fail()
            raise Exception(f"Failed to submit task: {str(e)}")
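
    # Submission/polling flow in short (comments only): submit_task returns
    # the request id read from result["data"]["id"]; poll_once (defined below)
    # then queries f"{API_BASE_URL}/predictions/{request_id}/result" until the
    # reported status becomes "completed" or "failed".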

    # Add this method to reset history
    def reset_history(self):
        """Reset history for all backends"""
        for status in self.backend_statuses.values():
            status.history = []  # Clear history data
        return self


# Helper function to create error images as data URIs
def create_error_image(backend, error_message):
    try:
        import base64
        from io import BytesIO

        # Create an in-memory image
        img = Image.new("RGB", (512, 512), color="#ffdddd")
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("Arial", 20)
        except Exception:
            font = ImageFont.load_default()

        # Wrap and draw error message
        words = error_message.split(" ")
        lines = []
        line = ""
        for word in words:
            if len(line + word) < 40:
                line += word + " "
            else:
                lines.append(line)
                line = word + " "
        if line:
            lines.append(line)

        y_position = 100
        for line in lines:
            draw.text((50, y_position), line, fill="black", font=font)
            y_position += 30

        # Save to a BytesIO object instead of a file
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        img_bytes = buffer.getvalue()

        # Convert to base64 and return as data URI
        return f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}"
    except Exception as e:
        print(f"Failed to create error image: {e}")
        # Return a simple error message as fallback
        return "Error: " + error_message


# Fix the poll_once function to accept a manager parameter
async def poll_once(manager, backend, request_id):
    """Poll once and return result if complete, otherwise None"""
    headers = {"Authorization": f"Bearer {API_KEY}"}
    url = f"{API_BASE_URL}/predictions/{request_id}/result"

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                result = await response.json()
                data = result["data"]
                current_status = data["status"]

                if current_status == "completed":
                    # IMPORTANT: Update status BEFORE returning - using the passed manager
                    manager.backend_statuses[backend].complete()
                    manager.update_activity()

                    # Handle base64 output
                    output = data["outputs"][0]
                    # Check if it's a base64 string or URL
                    if isinstance(output, str) and output.startswith("http"):
                        # It's a URL - return as is
                        return output
                    else:
                        # It's base64 data - format it as a data URI if needed
                        try:
                            # Format as data URI for Gradio to display directly
                            if isinstance(
                                    output,
                                    str) and not output.startswith("data:image"):
                                # Convert raw base64 to data URI format
                                return f"data:image/png;base64,{output}"
                            else:
                                # Already in data URI format
                                return output
                        except Exception as e:
                            print(f"Error processing base64 image: {e}")
                            raise Exception(
                                f"Failed to process base64 image: {str(e)}")
                elif current_status == "failed":
                    manager.backend_statuses[backend].fail()
                    manager.update_activity()
                    error = data.get("error", "Unknown error")
                    raise Exception(error)

                # Still processing
                return None
            else:
                raise Exception(f"Poll error: {response.status}")
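
# Polling sketch (comments only, mirrors the handler further below):
#   request_id = await manager.submit_task("flux-dev", prompt)
#   while True:
#       result = await poll_once(manager, "flux-dev", request_id)
#       if result is not None:
#           break                 # URL or data URI ready for gr.Image
#       await asyncio.sleep(0.1)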


# Use a state variable to store session ID
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    session_id = gr.State(None)  # Add this to store session ID

    gr.Markdown("# 🌊 WaveSpeed AI Image Generator")

    # Add the introduction with link to WaveSpeedAI
    gr.Markdown(
        "[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation."
    )
    gr.Markdown(
        "Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries."
    )

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type here...",
                lines=3,
            )
        with gr.Column(scale=1):
            generate_btn = gr.Button("Generate", variant="primary")

    # Two status boxes - small (default) and big (during generation)
    small_status_box = gr.Markdown("Ready to generate images",
                                   elem_id="small-status")

    # Big status box in its own row with styling
    with gr.Row(elem_id="big-status-row"):
        big_status_box = gr.Markdown("",
                                     elem_id="big-status",
                                     visible=False,
                                     elem_classes="big-status-box")

    with gr.Row():
        with gr.Column():
            draft_output = gr.Image(label="Flux-dev")
        with gr.Column():
            quick_output = gr.Image(label="HiDream-dev")
        with gr.Column():
            best_output = gr.Image(label="HiDream-full")

    performance_plot = gr.Plot(label="Performance Metrics")

    # Add custom CSS for the big status box
    css = """
    #big-status-row {
        margin: 20px 0;
    }
    #big-status {
        font-size: 28px;  /* Even larger font size */
        font-weight: bold;
        padding: 30px;  /* More padding */
        background-color: #0D47A1;  /* Deeper blue background */
        color: white;  /* White text */
        border-radius: 10px;
        text-align: center;
        margin: 0 auto;
        box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2);  /* Stronger shadow */
        animation: deep-breath 3s infinite;  /* Slower, deeper breathing animation */
        width: 100%;  /* Full width */
        max-width: 800px;  /* Maximum width */
        transition: all 0.3s ease;  /* Smooth transitions */
        border-left: 6px solid #64B5F6;  /* Add a colored border */
        border-right: 6px solid #64B5F6;  /* Add a colored border */
    }
    /* Deeper breathing animation */
    @keyframes deep-breath {
        0% {
            opacity: 0.7;
            transform: scale(0.98);
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
        }
        50% {
            opacity: 1;
            transform: scale(1.01);
            box-shadow: 0 8px 16px rgba(0, 0, 0, 0.3);
        }
        100% {
            opacity: 0.7;
            transform: scale(0.98);
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
        }
    }
    """
    gr.HTML(f"<style>{css}</style>")

    # Update the generation function to use session manager
    async def generate_all_backends_with_status_boxes(prompt,
                                                      current_session_id):
        """Generate images with big status box during generation"""
        # Get or create a session manager
        session_id, manager = SessionManager.get_manager(current_session_id)
        manager.update_activity()

        # IMPORTANT: Reset history when starting a new generation
        if prompt and prompt.strip() != "":
            manager.reset_history()  # Clear previous performance metrics

        if not prompt or prompt.strip() == "":
            # Handle empty prompt case
            yield (
                "⚠️ Please enter a prompt first",
                "⚠️ Please enter a prompt first",
                gr.update(visible=True),
                gr.update(visible=False),
                None,
                None,
                None,
                None,
                session_id,  # Return the session ID
            )
            return

        # Status message
        status_message = f"🔄 PROCESSING: '{prompt}'"

        # Initial state - clear all images, show big status box
        yield (
            status_message,
            status_message,
            gr.update(visible=True),
            gr.update(visible=False),
            None,
            None,
            None,
            None,
            session_id,  # Return the session ID
        )

        # For production mode:
        completed_backends = set()
        results = {"flux-dev": None, "hidream-dev": None, "hidream-full": None}

        try:
            # Submit all tasks
            request_ids = {}
            for backend in BACKENDS:
                try:
                    request_id = await manager.submit_task(backend, prompt)
                    request_ids[backend] = request_id
                except Exception as e:
                    # Handle submission error
                    print(f"Error submitting task for {backend}: {e}")
                    results[backend] = create_error_image(backend, str(e))
                    completed_backends.add(backend)

            # Poll all backends until they complete
            max_poll_attempts = 300
            poll_attempt = 0

            # Main polling loop
            while (len(completed_backends) < 3
                   and poll_attempt < max_poll_attempts):
                poll_attempt += 1

                # Poll each pending backend
                for backend in list(BACKENDS.keys()):
                    if backend in completed_backends:
                        continue

                    try:
                        # Only do actual API calls every few attempts to reduce load
                        if poll_attempt % 2 == 0 or backend == "flux-dev":
                            # Use the session manager instead of global manager
                            result = await poll_once(manager, backend,
                                                     request_ids[backend])
                            if result:  # Backend completed
                                results[backend] = result
                                completed_backends.add(backend)

                                # Yield updated state when an image completes
                                yield (
                                    status_message,
                                    status_message,
                                    gr.update(visible=True),
                                    gr.update(visible=False),
                                    results["flux-dev"],
                                    results["hidream-dev"],
                                    results["hidream-full"],
                                    (manager.get_performance_plot()
                                     if any(completed_backends) else None),
                                    session_id,
                                )
                    except Exception as e:
                        print(f"Error polling {backend}: {str(e)}")

                # Wait between poll attempts
                await asyncio.sleep(0.1)

            # Final status
            final_status = ("✅ All generations completed!"
                            if len(completed_backends) == 3 else
                            "⚠️ Some generations timed out")

            # Final yield
            yield (
                final_status,
                final_status,
                gr.update(visible=False),
                gr.update(visible=True),
                results["flux-dev"],
                results["hidream-dev"],
                results["hidream-full"],
                manager.get_performance_plot(),
                session_id,
            )
        except Exception as e:
            # Error handling
            error_message = f"❌ Error: {str(e)}"
            yield (
                error_message,
                error_message,
                gr.update(visible=False),
                gr.update(visible=True),
                None,
                None,
                None,
                None,
                session_id,
            )
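
    # Each yield above maps positionally onto the outputs list of the click
    # handler below: (small status text, big status text, big box visibility,
    # small box visibility, Flux-dev image, HiDream-dev image, HiDream-full
    # image, performance plot, session id).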

    # Schedule periodic cleanup of old sessions
    def cleanup_task():
        SessionManager.cleanup_old_sessions()
        # Schedule the next cleanup
        threading.Timer(3600, cleanup_task).start()  # Run every hour

    # Start the cleanup task
    cleanup_task()

    # Update the click handler to include session_id
    generate_btn.click(
        fn=generate_all_backends_with_status_boxes,
        inputs=[input_text, session_id],
        outputs=[
            small_status_box,
            big_status_box,
            big_status_box,  # visibility
            small_status_box,  # visibility
            draft_output,
            quick_output,
            best_output,
            performance_plot,
            session_id,  # Update the session ID
        ],
        api_name="generate",
        max_batch_size=10,  # Process up to 10 requests at once
        concurrency_limit=20,  # Allow up to 20 concurrent requests
        concurrency_id="generation",  # Group concurrent requests under this ID
    )

# Launch with increased max_threads
if __name__ == "__main__":
    demo.queue(max_size=50).launch(
        server_name="0.0.0.0",
        max_threads=16,  # Increase thread count for better concurrency
    )
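
# Run locally with `python app.py` (file name assumed); since no server_port
# is passed above, Gradio serves on its default port 7860.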