Spaces:
Sleeping
Sleeping
import html
import io
import json
import os
import random
import time
import traceback
import uuid
from urllib.parse import quote

import gradio as gr
import requests
from deep_translator import GoogleTranslator
from openai import OpenAI, RateLimitError, APIConnectionError, AuthenticationError
from PIL import Image, UnidentifiedImageError
# Project by Nymbo

# --- Constants ---
# NOTE(review): this targets FLUX.1-dev, while the download link's space name
# in query() says "schnell" — confirm which model is intended.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
API_TOKEN = os.getenv("HF_READ_TOKEN")

# Request headers for the HF Inference API. They stay empty when no token is
# configured; query() checks for that and refuses to call the API.
headers = {}
if API_TOKEN:
    headers["Authorization"] = f"Bearer {API_TOKEN}"
else:
    print("WARNING: HF_READ_TOKEN environment variable not set. API calls may fail.")

timeout = 100  # seconds allowed for each HF API request
IMAGE_DIR = "temp_generated_images"  # generated PNGs are written here
ARINTELI_REDIRECT_BASE = "https://arinteli.com/app/"  # download links are routed through this redirector
# --- Ensure temporary directory exists ---
# Generated images are saved here before a download link is built, so the app
# cannot work without it.
try:
    os.makedirs(IMAGE_DIR, exist_ok=True)
    print(f"Confirmed temporary image directory exists: {IMAGE_DIR}")
except OSError as e:
    print(f"ERROR: Could not create directory {IMAGE_DIR}: {e}")
    # Refuse to start. gr.Error is intended for event handlers (it is shown to
    # the user as a UI message); at import time a plain RuntimeError chained
    # to the underlying OSError is the appropriate fatal signal.
    raise RuntimeError(f"Fatal Error: Cannot create temporary image directory: {e}") from e

# --- Get Absolute Path for allowed_paths ---
# Computed before launch() so Gradio can be given an absolute directory to serve.
absolute_image_dir = os.path.abspath(IMAGE_DIR)
print(f"Absolute path for allowed_paths: {absolute_image_dir}")
# --- OpenAI Client ---
# Moderation is optional: whenever no usable client can be created,
# openai_client stays None and query() skips the safety check with a warning.
openai_client = None
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    print("WARNING: OPENAI_API_KEY environment variable not set or empty. Moderation will be skipped.")
else:
    try:
        # OpenAI() reads its credentials from the environment itself.
        openai_client = OpenAI()
    except Exception as e:
        openai_client = None
        print(f"ERROR: Failed to initialize OpenAI client: {e}. Moderation will be skipped.")
    else:
        print("OpenAI client initialized successfully.")
def _error_html(message, color="red"):
    """Return *message* as the centered, colored status paragraph shown in the UI.

    The text is HTML-escaped because it may embed content this app does not
    control (API error bodies, exception strings).
    """
    return f"<p style='color: {color}; text-align: center;'>{html.escape(str(message))}</p>"


def _log_event(status, reason, original_prompt, translated_prompt, *extra_lines):
    """Print one multi-line request record to stdout, followed by a blank line.

    Args:
        status: Heading such as "FAILED", "BLOCKED" or "SUCCESS".
        reason: Short cause printed as " Reason: ..."; omitted when None.
        original_prompt: Prompt as typed by the user.
        translated_prompt: English prompt; printed only when it differs.
        *extra_lines: Preformatted detail lines printed after the prompt.
    """
    print(f"{status}:")
    if reason is not None:
        print(f" Reason: {reason}")
    print(f" Prompt: '{original_prompt}'")
    if translated_prompt is not None and translated_prompt != original_prompt:
        print(f" Translated: '{translated_prompt}'")
    for line in extra_lines:
        print(f" {line}")
    print("")


def _translate_prompt(prompt):
    """Translate *prompt* to English, falling back to the original text on failure."""
    try:
        translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:
        print("WARNING: Translation failed. Using original prompt.")
        print(f" Error: {e}")
        print(f" Prompt: '{prompt}'\n")
        return prompt
    # Never let an empty/None translation result replace the user's prompt.
    return translated or prompt


def _moderation_error(original_prompt, translated_prompt):
    """Screen the prompt with OpenAI moderation before generation.

    Returns:
        None when generation may proceed (including when no OpenAI client is
        configured, in which case the check is skipped with a warning), or an
        error HTML snippet when the prompt is flagged or the check fails.
    """
    if not openai_client:
        print("WARNING: OpenAI client not available. Skipping moderation.")
        print(f" Prompt: '{original_prompt}'")
        if translated_prompt != original_prompt:
            print(f" (Would use translated: '{translated_prompt}')")
        print("")
        return None
    try:
        mod_response = openai_client.moderations.create(
            model="omni-moderation-latest",
            input=translated_prompt,
        )
        flagged_minors = mod_response.results[0].categories.sexual_minors
    except AuthenticationError:
        _log_event("BLOCKED", "OpenAI Auth Error", original_prompt, translated_prompt)
        return _error_html("Safety check failed. Generation blocked.")
    except RateLimitError:
        _log_event("BLOCKED", "OpenAI Rate Limit", original_prompt, translated_prompt)
        return _error_html("Safety check failed. Please try again later.")
    except APIConnectionError as e:
        _log_event("BLOCKED", "OpenAI Connection Error", original_prompt, translated_prompt, f"Error: {e}")
        return _error_html("Safety check failed. Please try again later.")
    except Exception as e:
        # Fail closed: any other moderation failure blocks generation.
        _log_event("BLOCKED", "OpenAI Unexpected Error", original_prompt, translated_prompt, f"Error: {e}")
        traceback.print_exc()
        return _error_html("An unexpected error occurred during safety check. Generation blocked.")
    if flagged_minors:
        _log_event("BLOCKED", "sexual/minors", original_prompt, translated_prompt)
        return _error_html("Prompt violates safety guidelines. Generation blocked.")
    return None


def _describe_http_error(err):
    """Turn an HF API ``HTTPError`` into ``(status_code, detail_text, user_message)``.

    The response body is parsed as JSON when possible and an ``"error"`` entry
    is flattened to text. Bodies that are not JSON, or JSON that is not an
    object, fall back to the raw response text instead of raising.
    """
    response = err.response
    status_code = response.status_code
    error_text = response.text
    try:
        error_data = response.json()
    except ValueError:
        # Not JSON (requests' JSON decode errors subclass ValueError).
        error_data = None
    if isinstance(error_data, dict):
        parsed_error = error_data.get("error", error_text)
        if isinstance(parsed_error, dict) and "message" in parsed_error:
            error_text = str(parsed_error["message"])
        elif isinstance(parsed_error, list):
            error_text = "; ".join(map(str, parsed_error))
        else:
            error_text = str(parsed_error)

    if status_code == 503:
        estimated_time = error_data.get("estimated_time") if isinstance(error_data, dict) else None
        eta = ""
        if isinstance(estimated_time, (int, float)) and estimated_time:
            eta = f" Est. time: {estimated_time:.1f}s."
        message = f"Model is loading (503), please wait.{eta} Try again."
    elif status_code == 400:
        message = f"Bad Request (400): Check parameters. API Error: {error_text}"
    elif status_code == 422:
        message = f"Validation Error (422): Input invalid. API Error: {error_text}"
    elif status_code in (401, 403):
        message = f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
    else:
        message = f"API Error: {status_code}. Details: {error_text}"
    return status_code, error_text, message


def _save_image_with_link(image, original_prompt, translated_prompt):
    """Save *image* as a PNG under IMAGE_DIR and build the download-button HTML.

    Returns:
        ``(image, html)``. *image* is returned even when saving fails so the
        user still sees the result; *html* is then an error message instead of
        the download button.
    """
    filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    save_path = os.path.join(IMAGE_DIR, filename)
    try:
        image.save(save_path, "PNG")
        saved_ok = os.path.exists(save_path) and os.path.getsize(save_path) >= 100
    except OSError as save_err:
        _log_event("FAILED", "Image Save IO Error", original_prompt, translated_prompt,
                   f"Path: '{save_path}'", f"Error: {save_err}")
        traceback.print_exc()
        return image, _error_html(f"Internal Error: Failed to save image file. Details: {save_err}")
    if not saved_ok:
        _log_event("FAILED", "Image Save Verification Error", original_prompt, translated_prompt,
                   f"Path: '{save_path}'")
        return image, _error_html("Internal Error: Failed to confirm image file save.")

    # NOTE(review): the space name says "schnell" while API_URL targets
    # FLUX.1-dev — confirm which Space the redirector should reference.
    space_name = "greendra-flux-1-schnell-serverless"
    # The file is served through Gradio's file route (IMAGE_DIR is passed in
    # allowed_paths at launch). The route is percent-encoded so it travels as a
    # single query-parameter value to the redirector.
    encoded_file_url = quote(f"gradio_api/file={save_path}")
    arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
    download_html = (
        '<div style="text-align: right; margin-right: -8px;">'
        f'<a href="{html.escape(arinteli_url)}" target="_blank" style="background: #3dd49f; color: white; padding: 7px 25px; border-radius: 6px; text-decoration: none;">'
        'Download Image'
        '</a>'
        '</div>'
    )
    _log_event("SUCCESS", None, original_prompt, translated_prompt, arinteli_url)
    return image, download_html


# Function to query the API and return the generated image and download link
def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
    """Generate an image for *prompt* via the HF Inference API.

    Pipeline: translate the prompt to English, screen it with OpenAI
    moderation (when configured), call the model, then save the PNG and build
    a download link.

    Args:
        prompt: User prompt in any language.
        negative_prompt: Things the model should avoid.
        steps: Number of inference steps.
        cfg_scale: Guidance scale.
        seed: Random seed; -1 picks a random seed in [1, 1000000000].
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        ``(image, html)``. *image* is a PIL image or None. *html* is either the
        download button or a status/error message; handled failures are
        reported there rather than raised.
    """
    if not prompt or not prompt.strip():
        print("WARNING: Empty prompt received.\n")
        return None, _error_html("Please enter a prompt.", color="orange")

    original_prompt = prompt
    translated_prompt = _translate_prompt(prompt)

    blocked_html = _moderation_error(original_prompt, translated_prompt)
    if blocked_html is not None:
        return None, blocked_html

    if not headers:
        _log_event("FAILED", "HF Token Missing", original_prompt, translated_prompt)
        return None, _error_html("Configuration Error: API Token missing.")

    try:
        # UI sliders may hand over floats; send whole numbers in the payload.
        seed = int(seed)
        payload = {
            "inputs": f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect.",
            "parameters": {
                "negative_prompt": negative_prompt,
                "num_inference_steps": int(steps),
                "guidance_scale": cfg_scale,
                "seed": seed if seed != -1 else random.randint(1, 1000000000),
                "width": int(width),
                "height": int(height),
            },
        }
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()

        image_bytes = response.content
        if len(image_bytes) < 100:
            _log_event("FAILED", "Invalid Image Data (Empty/Small)", original_prompt, translated_prompt,
                       f"Length: {len(image_bytes)}")
            return None, _error_html("API returned invalid image data.")

        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; load() decodes now so truncated/corrupt data
            # is reported here instead of failing later during save.
            image.load()
        except OSError as img_err:  # includes PIL's UnidentifiedImageError
            _log_event("FAILED", "Image Processing Error", original_prompt, translated_prompt,
                       f"Error: {img_err}")
            return None, _error_html("Failed to process image data from API.")

        return _save_image_with_link(image, original_prompt, translated_prompt)

    except requests.exceptions.Timeout:
        _log_event("FAILED", "HF API Timeout", original_prompt, translated_prompt)
        return None, _error_html("Request timed out. The model is taking too long.")
    except requests.exceptions.HTTPError as e:
        status_code, detail, message = _describe_http_error(e)
        _log_event("FAILED", f"HF API HTTP Error {status_code}", original_prompt, translated_prompt,
                   f"Details: '{detail[:200]}'")
        return None, _error_html(message)
    except Exception as e:
        _log_event("FAILED", "Unexpected Error During Generation", original_prompt, translated_prompt,
                   f"Error: {e}")
        traceback.print_exc()
        return None, _error_html(f"An unexpected error occurred: {e}")
# Page styling: center the app column, darken focused text areas, and space
# out status messages rendered in the download-link area.
css = """
#app-container {
    max-width: 800px;
    margin-left: auto;
    margin-right: auto;
}
textarea:focus {
    background: #0d1117 !important;
}
#download-link-container p {
    margin-top: 10px;
    font-size: 0.9em;
}
"""

# Pre-filled value for the "Negative Prompt" box.
_DEFAULT_NEGATIVE_PROMPT = (
    "(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, "
    "extra limb, missing limb, floating limbs, (mutated hands and fingers), "
    "disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, "
    "misspellings, typos"
)

# Build the Gradio UI with Blocks
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as app:
    with gr.Column(elem_id="app-container"):
        # --- Input controls ---
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should not be in the image",
                            value=_DEFAULT_NEGATIVE_PROMPT,
                            lines=3,
                            elem_id="negative-prompt-text-input",
                        )
                        with gr.Row():
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
                        steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
                        cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
                        seed = gr.Slider(
                            label="Seed",
                            value=-1,
                            minimum=-1,
                            maximum=1000000000,
                            step=1,
                            info="Set to -1 for random seed",
                        )
        with gr.Row():
            text_button = gr.Button("Run", variant="primary", elem_id="gen-button")

        # --- Output Components ---
        with gr.Row():
            image_output = gr.Image(
                type="pil",
                label="Image Output",
                elem_id="gallery",
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
            )
        with gr.Row():
            # Shows either status/error messages or the download button.
            download_link_display = gr.HTML(elem_id="download-link-container")

        # Order must match query(prompt, negative_prompt, steps, cfg_scale, seed, width, height).
        generation_inputs = [text_prompt, negative_prompt, steps, cfg, seed, width, height]
        text_button.click(
            fn=query,
            inputs=generation_inputs,
            outputs=[image_output, download_link_display],
        )
# Launch the Gradio app only when run as a script, so importing this module
# (tests, tooling) does not start a server. allowed_paths lets Gradio serve
# the PNGs written to the temporary image directory that download links
# point at.
if __name__ == "__main__":
    print("Starting Gradio app...")
    app.launch(
        show_api=False,
        share=False,
        allowed_paths=[absolute_image_dir],
    )