greendra's picture
Update app.py
aa7effe verified
raw
history blame
16.6 kB
import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image, UnidentifiedImageError # Added UnidentifiedImageError
from deep_translator import GoogleTranslator
import json
import uuid
from urllib.parse import quote
import traceback # For detailed error logging
from openai import OpenAI, RateLimitError, APIConnectionError, AuthenticationError # Added OpenAI errors
# Project by Nymbo

# --- Constants ---
# Hugging Face Inference API endpoint for the SD 3.5 Large Turbo model.
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large-turbo"

# Read-scoped HF token. When it is missing, `headers` stays empty and query()
# refuses to call the API.
API_TOKEN = os.getenv("HF_READ_TOKEN")
if API_TOKEN:
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
else:
    print("WARNING: HF_READ_TOKEN environment variable not set. API calls may fail.")
    headers = {}

timeout = 100  # seconds allowed for each HF API request
IMAGE_DIR = "temp_generated_images"  # generated PNGs are written here for download
ARINTELI_REDIRECT_BASE = "https://arinteli.com/app/"  # base URL of the external download redirector
# --- Ensure temporary directory exists ---
try:
    os.makedirs(IMAGE_DIR, exist_ok=True)
    print(f"Confirmed temporary image directory exists: {IMAGE_DIR}")
except OSError as e:
    print(f"ERROR: Could not create directory {IMAGE_DIR}: {e}")
    # Without this directory no image can be saved, so abort startup.
    # A plain RuntimeError is used because gr.Error is meant for reporting
    # errors from inside event handlers, and no UI exists yet at import time.
    # Chaining with `from e` keeps the original OS error in the traceback.
    raise RuntimeError(f"Fatal Error: Cannot create temporary image directory: {e}") from e

# --- Get Absolute Path for allowed_paths ---
# Must be computed before launch(), which receives this path in allowed_paths
# so Gradio is permitted to serve the saved images.
absolute_image_dir = os.path.abspath(IMAGE_DIR)
print(f"Absolute path for allowed_paths: {absolute_image_dir}")
# --- OpenAI Client ---
# Moderation is optional. With no OPENAI_API_KEY, or if building the client
# fails, openai_client stays None and query() skips the safety check.
openai_client = None
try:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if openai_api_key:
        openai_client = OpenAI()  # the library reads OPENAI_API_KEY from the environment itself
        print("OpenAI client initialized successfully.")
    else:
        print("WARNING: OPENAI_API_KEY environment variable not set or empty. Moderation will be skipped.")
except Exception as exc:
    print(f"ERROR: Failed to initialize OpenAI client: {exc}. Moderation will be skipped.")
    openai_client = None
# --- Function to query the API and return the generated image and download link ---
def _html_message(text, color="red"):
    """Return *text* wrapped in a centered, colored <p> for the status HTML output."""
    return f"<p style='color: {color}; text-align: center;'>{text}</p>"


def _log_event(status, original_prompt, translated_prompt=None, reason=None, details=()):
    """Print one multi-line console log record, then a blank separator line.

    Args:
        status: Header word, e.g. "BLOCKED", "FAILED" or "SUCCESS".
        original_prompt: The prompt exactly as the user typed it.
        translated_prompt: Printed only when given and different from
            original_prompt.
        reason: Optional short cause, printed as " Reason: ...".
        details: Extra lines, each printed with one leading space.
    """
    print(f"{status}:")
    if reason:
        print(f" Reason: {reason}")
    print(f" Prompt: '{original_prompt}'")
    if translated_prompt is not None and translated_prompt != original_prompt:
        print(f" Translated: '{translated_prompt}'")
    for line in details:
        print(f" {line}")
    print("")


def _translate_to_english(prompt):
    """Translate *prompt* to English, falling back to the original text on failure."""
    try:
        translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:
        print("WARNING: Translation failed. Using original prompt.")
        print(f" Error: {e}")
        print(f" Prompt: '{prompt}'\n")
        return prompt
    # Guard against an empty/None result so later steps always receive text.
    return translated or prompt


def _moderate(translated_prompt, original_prompt):
    """Run the OpenAI moderation check on *translated_prompt*.

    Returns:
        None when generation may proceed (including when no OpenAI client is
        configured). Otherwise, the error HTML to show the user. Any failure
        of the moderation call blocks generation.
    """
    if not openai_client:
        print("WARNING: OpenAI client not available. Skipping moderation.")
        print(f" Prompt: '{original_prompt}'")  # Log original prompt even if skipping
        if translated_prompt != original_prompt:
            print(f" (Would use translated: '{translated_prompt}')")
        print("")
        return None
    try:
        mod_response = openai_client.moderations.create(
            model="omni-moderation-latest",
            input=translated_prompt,
        )
        # Only the sexual/minors category blocks generation.
        if mod_response.results[0].categories.sexual_minors:
            _log_event("BLOCKED", original_prompt, translated_prompt, reason="sexual/minors")
            return _html_message("Prompt violates safety guidelines. Generation blocked.")
    except AuthenticationError:
        _log_event("BLOCKED", original_prompt, translated_prompt, reason="OpenAI Auth Error")
        return _html_message("Safety check failed. Generation blocked.")
    except RateLimitError:
        _log_event("BLOCKED", original_prompt, translated_prompt, reason="OpenAI Rate Limit")
        return _html_message("Safety check failed. Please try again later.")
    except APIConnectionError as e:
        _log_event("BLOCKED", original_prompt, translated_prompt,
                   reason="OpenAI Connection Error", details=[f"Error: {e}"])
        return _html_message("Safety check failed. Please try again later.")
    except Exception as e:
        _log_event("BLOCKED", original_prompt, translated_prompt,
                   reason="OpenAI Unexpected Error", details=[f"Error: {e}"])
        traceback.print_exc()  # Keep traceback for unexpected errors
        return _html_message("An unexpected error occurred during safety check. Generation blocked.")
    return None


def _parse_hf_error(response):
    """Extract a readable message and the decoded JSON body from an HF error response.

    Returns:
        (error_text, error_data). error_data is the decoded JSON object, or {}
        when the body is not a JSON object.
    """
    error_text = response.text
    try:
        error_data = response.json()
    except ValueError:
        # When simplejson is installed, requests raises a JSONDecodeError derived
        # from simplejson's class, not json.JSONDecodeError. Both derive from
        # ValueError, so catching ValueError covers every case.
        return error_text, {}
    if not isinstance(error_data, dict):
        # A JSON list/string body has no "error" key to look up.
        return error_text, {}
    parsed_error = error_data.get("error", error_text)
    if isinstance(parsed_error, dict) and "message" in parsed_error:
        error_text = str(parsed_error["message"])
    elif isinstance(parsed_error, list):
        error_text = "; ".join(map(str, parsed_error))
    else:
        error_text = str(parsed_error)
    return error_text, error_data


def _describe_http_error(status_code, error_data):
    """Map an HF API HTTP status code to the short message shown to the user."""
    if status_code == 503:
        estimated_time = error_data.get("estimated_time")
        # Only format numeric estimates; a malformed value must not crash the handler.
        eta = (
            f" Est. time: {estimated_time:.1f}s."
            if isinstance(estimated_time, (int, float)) and estimated_time
            else ""
        )
        return f"Model is loading (503), please wait.{eta} Try again."
    messages = {
        400: "Bad Request (400): Check parameters.",
        401: "Authorization Error (401): Check API Token.",
        403: "Authorization Error (403): Check API Token.",
        422: "Validation Error (422): Input invalid.",
        429: "Rate Limit Error (429): Too many requests. Try again later.",
    }
    return messages.get(status_code, f"API Error ({status_code}). Please try again.")


def _save_with_download_link(image, original_prompt, translated_prompt):
    """Save *image* as a PNG in IMAGE_DIR and build the download-button HTML.

    Returns:
        (image, html). The image is returned even when saving fails, so the
        user still sees it. In that case html carries an error message
        instead of the download link.
    """
    # Timestamp plus a random hex suffix makes filename collisions unlikely.
    filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    save_path = os.path.join(IMAGE_DIR, filename)
    # NOTE(review): nothing in this file deletes these files; confirm a cleanup
    # exists elsewhere, otherwise IMAGE_DIR grows without bound.
    try:
        image.save(save_path, "PNG")
        if not os.path.exists(save_path) or os.path.getsize(save_path) < 100:
            _log_event("FAILED", original_prompt, translated_prompt,
                       reason="Image Save Verification Error",
                       details=[f"Path: '{save_path}'"])
            return image, _html_message("Internal Error: Failed to confirm image file save.")
        space_name = "greendra-stable-diffusion-3-5-large-serverless"
        # Relative Gradio file URL for the saved image, percent-encoded so it can
        # travel as the redirector's download_url query parameter.
        encoded_file_url = quote(f"gradio_api/file={save_path}")
        arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
        download_html = (
            '<div style="text-align: center;">'
            f'<a href="{arinteli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
            'Download Image'
            '</a>'
            '</div>'
        )
        _log_event("SUCCESS", original_prompt, translated_prompt, details=[arinteli_url])
        return image, download_html
    except OSError as save_err:  # IOError is an alias of OSError
        _log_event("FAILED", original_prompt, translated_prompt,
                   reason="Image Save IO Error",
                   details=[f"Path: '{save_path}'", f"Error: {save_err}"])
        traceback.print_exc()
        return image, _html_message("Internal Error: Failed to save image file.")
    except Exception as e:
        _log_event("FAILED", original_prompt, translated_prompt,
                   reason="Link Creation/Save Unexpected Error",
                   details=[f"Error: {e}"])
        traceback.print_exc()
        return image, _html_message("Internal Error creating download link.")


def query(prompt, negative_prompt, steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
    """Generate an image for *prompt* through the Hugging Face Inference API.

    Pipeline:
    1. Reject empty prompts.
    2. Translate the prompt to English (best effort).
    3. Run the OpenAI moderation check when a client is configured.
    4. Call the HF model.
    5. Save the PNG under IMAGE_DIR and build a download link via the
       arinteli redirector.

    Args:
        prompt: User prompt; may be in any language.
        negative_prompt: Text describing what should not appear.
        steps: num_inference_steps sent to the model.
        cfg_scale: guidance_scale sent to the model.
        seed: Generation seed; -1 picks a random seed per request.
        width, height: Output image size in pixels.

    Returns:
        (image, html) for the Gradio outputs. image is a PIL image or None.
        html is the download-button markup or a colored status/error
        paragraph. Failures are logged to the console and reported through
        html rather than raised.
    """
    if not prompt or not prompt.strip():
        print("WARNING: Empty prompt received.\n")
        return None, _html_message("Please enter a prompt.", color="orange")

    original_prompt = prompt
    translated_prompt = _translate_to_english(prompt)

    blocked_html = _moderate(translated_prompt, original_prompt)
    if blocked_html is not None:
        return None, blocked_html

    # --- Proceed with Generation ---
    payload = {
        "inputs": f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect.",
        "parameters": {
            "width": width,
            "height": height,
            "num_inference_steps": steps,
            "negative_prompt": negative_prompt,
            "guidance_scale": cfg_scale,
            "seed": seed if seed != -1 else random.randint(1, 1000000000),
        },
    }

    if not headers:
        _log_event("FAILED", original_prompt, translated_prompt, reason="HF Token Missing")
        return None, _html_message("Configuration Error: API Token missing.")

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        image_bytes = response.content
        if not image_bytes or len(image_bytes) < 100:
            _log_event("FAILED", original_prompt, translated_prompt,
                       reason="Invalid Image Data (Empty/Small)",
                       details=[f"Length: {len(image_bytes or b'')}"])
            return None, _html_message("API returned invalid image data.")
        try:
            image = Image.open(io.BytesIO(image_bytes))
        except UnidentifiedImageError as img_err:
            _log_event("FAILED", original_prompt, translated_prompt,
                       reason="Image Processing Error",
                       details=[f"Error: {img_err}"])
            return None, _html_message("Failed to process image data from API.")
        return _save_with_download_link(image, original_prompt, translated_prompt)
    except requests.exceptions.Timeout:
        _log_event("FAILED", original_prompt, translated_prompt, reason="HF API Timeout")
        return None, _html_message("Request timed out. The model is taking too long.")
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        error_text, error_data = _parse_hf_error(e.response)
        _log_event("FAILED", original_prompt, translated_prompt,
                   reason=f"HF API HTTP Error {status_code}",
                   details=[f"Details: '{error_text[:200]}'"])
        return None, _html_message(_describe_http_error(status_code, error_data))
    except Exception as e:
        # Top-level boundary: log everything unexpected and report it in the UI.
        _log_event("FAILED", original_prompt, translated_prompt,
                   reason="Unexpected Error During Generation",
                   details=[f"Error: {e}"])
        traceback.print_exc()
        return None, _html_message("An unexpected error occurred. Please check logs.")
# --- CSS Styling ---
# Custom stylesheet passed to gr.Blocks(css=...). The ID selectors match the
# elem_id values set on UI components: #app-container is the main column
# (centered, capped at 800px wide) and #download-link-container is the HTML
# output that shows the download link or status messages.
css = """
#app-container {
max-width: 800px;
margin-left: auto;
margin-right: auto;
}
textarea:focus {
background: #0d1117 !important;
}
#download-link-container p { /* Style the link container */
margin-top: 10px; /* Add some space above the link */
font-size: 0.9em; /* Slightly smaller text for the link message */
}
"""
# --- Build the Gradio UI with Blocks ---
# Layout, top to bottom:
# - prompt box;
# - "Advanced Settings" accordion (negative prompt, width/height, steps, CFG, seed);
# - Run button;
# - image output;
# - HTML area for the download link or status messages.
# Component creation order inside the nested `with` contexts determines the
# on-page layout.
# NOTE(review): the nesting below was reconstructed from indentation-stripped
# source; confirm it matches the intended layout.
with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
    with gr.Column(elem_id="app-container"):
        # --- Input Components ---
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input"
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should not be in the image",
                            value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
                            lines=3,
                            elem_id="negative-prompt-text-input"
                        )
                        with gr.Row():
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
                        steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
                        cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
                        # -1 is the sentinel query() replaces with a random seed.
                        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
        # --- Action Button ---
        with gr.Row():
            text_button = gr.Button("Run", variant='primary', elem_id="gen-button")
        # --- Output Components ---
        with gr.Row():
            # type="pil" matches query() returning a PIL image (or None).
            image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
        with gr.Row():
            download_link_display = gr.HTML(elem_id="download-link-container")
    # --- Event Listener ---
    # `inputs` order must match query()'s positional parameters:
    # (prompt, negative_prompt, steps, cfg_scale, seed, width, height).
    # query() returns (image, html), which fill `outputs` in order.
    text_button.click(
        query,
        inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
        outputs=[image_output, download_link_display]
    )
# --- Launch the Gradio app ---
# allowed_paths lets Gradio serve files from the temporary image directory,
# which the download links built in query() point at. Add
# server_name="0.0.0.0" to these options if the app must bind all interfaces.
launch_options = dict(
    show_api=False,
    share=False,
    allowed_paths=[absolute_image_dir],
)
print("Starting Gradio app...")
app.launch(**launch_options)