File size: 14,800 Bytes
e547b24 fb64c67 e547b24 fb64c67 e547b24 fb64c67 6a90ab9 e547b24 fb64c67 e547b24 fb64c67 e547b24 fb64c67 5c1b65e fb64c67 e547b24 fb64c67 e547b24 fb64c67 e547b24 fb64c67 4d6cbec fb64c67 e547b24 fb64c67 e547b24 fb64c67 e547b24 fb64c67 5c1b65e fb64c67 5c1b65e fb64c67 e547b24 fb64c67 e547b24 fb64c67 e547b24 02f8cfa 4d6cbec 02f8cfa 73f7edc b64fadc fb64c67 e547b24 fb64c67 02f8cfa a6e241d fb64c67 02f8cfa fb64c67 02f8cfa fb64c67 02f8cfa fb64c67 4d6cbec 02f8cfa fb64c67 e547b24 fb64c67 02f8cfa fb64c67 02f8cfa fb64c67 e547b24 fb64c67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import html
import io
import json
import os
import random
import time
import traceback # For detailed error logging
import uuid
from urllib.parse import quote

import gradio as gr
import requests
from deep_translator import GoogleTranslator
from PIL import Image, UnidentifiedImageError # Added UnidentifiedImageError
# Project by Nymbo
# --- Constants ---
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large-turbo"
API_TOKEN = os.getenv("HF_READ_TOKEN")
if not API_TOKEN:
    # Deliberately not fatal: with an empty `headers` dict the request handler
    # reports a configuration error in the UI instead of calling the API.
    print("WARNING: HF_READ_TOKEN environment variable not set. API calls may fail.")
# Empty when no token is configured; request code treats {} as "not configured".
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
timeout = 100  # seconds for API call timeout
IMAGE_DIR = "temp_generated_images"  # Directory to store temporary images (relative to the CWD)
ARINTELLI_REDIRECT_BASE = "https://arintelli.com/app/"  # Your redirector URL

# --- Ensure temporary directory exists ---
try:
    os.makedirs(IMAGE_DIR, exist_ok=True)
except OSError as e:
    print(f"ERROR: Could not create directory {IMAGE_DIR}: {e}")
    # Fatal: generated images could not be saved. gr.Error is intended for
    # event handlers (it is shown in the UI); at import time a plain
    # RuntimeError chained to the OS error is the appropriate signal.
    raise RuntimeError(f"Fatal Error: Cannot create temporary image directory: {e}") from e
else:
    print(f"Confirmed temporary image directory exists: {IMAGE_DIR}")

# --- Get Absolute Path for allowed_paths ---
# Must be computed before launch(), which receives it via allowed_paths.
absolute_image_dir = os.path.abspath(IMAGE_DIR)
print(f"Absolute path for allowed_paths: {absolute_image_dir}")
# --- Function to query the API and return the generated image and download link ---
def _status_html(message, color="red"):
    """Wrap *message* in the centred status paragraph shown in the HTML output.

    *message* is inserted verbatim; callers must escape untrusted text.
    """
    return f"<p style='color: {color}; text-align: center;'>{message}</p>"


def _build_payload(final_prompt, negative_prompt, steps, cfg_scale, seed, width, height):
    """Build the JSON body for the HF Inference API text-to-image request.

    All generation options go inside ``parameters``, using the task's
    parameter names (``num_inference_steps``, ``guidance_scale``, ...).
    A seed of -1 is replaced by a random seed in [1, 1_000_000_000].
    Integer options are coerced with int() because slider values may arrive
    as floats. An empty negative prompt is omitted rather than sent as "".

    NOTE(review): because steps/CFG are now forwarded, the UI defaults
    (35 steps, CFG 7) take effect; confirm they suit a "turbo" checkpoint.
    """
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 1000000000)
    parameters = {
        "num_inference_steps": int(steps),
        "guidance_scale": cfg_scale,
        "seed": seed,
        "width": int(width),
        "height": int(height),
    }
    if negative_prompt:
        parameters["negative_prompt"] = negative_prompt
    return {"inputs": final_prompt, "parameters": parameters}


def _http_error_message(response):
    """Return a short user-facing message (NOT HTML-escaped) for a failed API *response*."""
    status_code = response.status_code
    error_text = response.text  # Fallback when the body is not structured JSON
    error_data = None
    try:
        error_data = response.json()
    except ValueError:
        # Response.json() signals a non-JSON body with a ValueError subclass;
        # catching ValueError covers both the json and simplejson backends.
        pass
    if isinstance(error_data, dict):
        error_text = error_data.get("error", error_text)
        if isinstance(error_text, dict) and "message" in error_text:
            error_text = error_text["message"]  # Handle nested messages
    print(f"Error: Failed API call. Status: {status_code}, Response: {error_text}")

    if status_code == 503:  # Service Unavailable (often model loading)
        estimated_time = error_data.get("estimated_time") if isinstance(error_data, dict) else None
        # Only format numeric estimates; anything else falls back to the generic text.
        if estimated_time and isinstance(estimated_time, (int, float)):
            return f"Model is loading (503), please wait. Est. time: {estimated_time:.1f}s. Try again."
        return "Service unavailable (503). Model might be loading or down. Try again later."
    if status_code == 400:  # Bad Request (invalid parameters)
        return f"Bad Request (400): Check parameters. API Error: {error_text}"
    if status_code == 422:  # Unprocessable Entity (validation error)
        return f"Validation Error (422): Input invalid. API Error: {error_text}"
    if status_code in (401, 403):  # Unauthorized / Forbidden
        return f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
    return f"API Error: {status_code}. Details: {error_text}"


def _save_and_link(image, key):
    """Save *image* as a PNG under IMAGE_DIR and return the status HTML.

    On success, returns a "Download Image" button linking through
    ARINTELLI_REDIRECT_BASE to the saved file's /gradio_api/file= URL.
    A failed save (OSError) or a missing file yields an error paragraph;
    any other exception propagates to the caller.
    """
    filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    # save_path is relative to the script's working directory
    save_path = os.path.join(IMAGE_DIR, filename)
    absolute_save_path = os.path.abspath(save_path)  # used for logging only
    try:
        image.save(save_path, "PNG")
    except OSError as save_err:  # IOError is an alias of OSError
        print(f"CRITICAL ERROR: Failed to save image to {save_path} (Absolute: {absolute_save_path}): {save_err}")
        traceback.print_exc()
        return _status_html(
            f"Internal Error: Failed to save image file. Details: {html.escape(str(save_err))}"
        )

    if not os.path.exists(save_path):
        # save() returned without raising, yet the file is not there.
        print(f"CRITICAL ERROR: File NOT found at {save_path} (Absolute: {absolute_save_path}) immediately after saving!")
        return _status_html("Internal Error: Failed to confirm image file save.")
    file_size = os.path.getsize(save_path)
    if file_size < 100:  # Warn if the saved file is suspiciously small
        print(f"WARNING: Saved file {save_path} is very small ({file_size} bytes). May indicate an issue.")

    # Hard-coded Space identifier forwarded to the redirector.
    space_name = "greendra-stable-diffusion-3-5-large-serverless"
    relative_file_url = f"/gradio_api/file={save_path}"
    arintelli_url = f"{ARINTELLI_REDIRECT_BASE}?download_url={quote(relative_file_url)}&space_name={space_name}"
    print(f"{arintelli_url}")
    print(f"--- Generation {key} Done ---")
    # Escape for the href attribute ('&' becomes '&amp;', as HTML requires).
    return (
        f'<div style="text-align: center;">'
        f'<a href="{html.escape(arintelli_url)}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
        f'Download Image'
        f'</a>'
        f'</div>'
    )


def query(prompt, negative_prompt, steps=35, cfg_scale=7, seed=-1, width=1024, height=1024):
    """Generate an image for *prompt* and return it with a status/download HTML snippet.

    The prompt is machine-translated to English, falling back to the original
    text on failure. A fixed quality suffix is appended and the result is posted
    to API_URL. The decoded image is saved under IMAGE_DIR and a download button
    is built for it.

    Args:
        prompt: User prompt, any language.
        negative_prompt: What the image should avoid; omitted from the request if empty.
        steps: Number of inference steps.
        cfg_scale: Guidance scale.
        seed: Generation seed; -1 picks a random one.
        width, height: Requested output size in pixels.

    Returns:
        ``(image, html)``. *image* is a PIL image, or None when none could be
        produced. *html* is either the download-link markup or a coloured status
        paragraph; generation failures are reported there rather than raised.
    """
    if not prompt or not prompt.strip():
        print("Empty prompt received.")
        return None, _status_html("Please enter a prompt.", color="orange")
    # Fail fast, before spending a translation call, when no token is configured.
    if not headers:
        print("WARNING: Authorization header is missing (HF_READ_TOKEN not set?)")
        return None, _status_html("Configuration Error: API Token missing.")

    key = random.randint(0, 999)  # Short id to correlate the log lines of one request
    print(f"\n--- Generation {key} Started ---")

    try:
        translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:  # Best-effort: any translator failure falls back to the raw prompt
        print(f"Translation failed: {e}. Using original prompt.")
        translated_prompt = None
    translated_prompt = translated_prompt or prompt  # also covers an empty translation
    final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation {key} prompt: {final_prompt}')

    try:
        payload = _build_payload(final_prompt, negative_prompt, steps, cfg_scale, seed, width, height)
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # Raises HTTPError for 4xx/5xx responses
        image_bytes = response.content
        if not image_bytes or len(image_bytes) < 100:  # Basic check for empty/tiny response
            print(f"Error: Received empty or very small response content (length: {len(image_bytes)}). Potential API issue.")
            return None, _status_html("API returned invalid image data.")
        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy: decode now so corrupt/truncated data is
            # reported here instead of surfacing later as a "save" failure.
            image.load()
        except (UnidentifiedImageError, OSError) as img_err:
            print(f"Error: Could not identify or open image from API response bytes: {img_err}")
            return None, _status_html("Failed to process image data from API.")
        try:
            return image, _save_and_link(image, key)
        except Exception as e:
            # Unexpected failure while saving or building the link: still show the image.
            print(f"Error during link creation or unexpected save issue: {e}")
            traceback.print_exc()
            return image, _status_html("Internal Error creating download link.")
    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None, _status_html("Request timed out. The model is taking too long.")
    except requests.exceptions.HTTPError as e:
        # The message may contain API-supplied text, so escape it before embedding.
        return None, _status_html(html.escape(_http_error_message(e.response)))
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        return None, _status_html(f"An unexpected error occurred: {html.escape(str(e))}")
# --- CSS Styling ---
# Custom stylesheet passed to gr.Blocks(css=...). The two ID selectors match
# elem_id values assigned in the UI layout:
#   #app-container              - main column: capped at 800px wide and centred
#   textarea:focus              - forces a dark background on focused text areas
#   #download-link-container p  - spacing/size of status paragraphs in the HTML output
css = """
#app-container {
max-width: 800px;
margin-left: auto;
margin-right: auto;
}
textarea:focus {
background: #0d1117 !important;
}
#download-link-container p { /* Style the link container */
margin-top: 10px; /* Add some space above the link */
font-size: 0.9em; /* Slightly smaller text for the link message */
}
"""
# --- Build the Gradio UI with Blocks ---
# Default text pre-filled in the "Negative Prompt" box.
DEFAULT_NEGATIVE_PROMPT = (
    "(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, "
    "extra limb, missing limb, floating limbs, (mutated hands and fingers), "
    "disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, "
    "misspellings, typos"
)

with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
    gr.HTML("<center><h1>Stable Diffusion 3.5 Turbo</h1></center>")
    with gr.Column(elem_id="app-container"):
        # --- Input Components ---
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should not be in the image",
                            value=DEFAULT_NEGATIVE_PROMPT,
                            lines=3,
                            elem_id="negative-prompt-text-input",
                        )
                        with gr.Row():
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
                        # NOTE(review): 35 steps / CFG 7 may be high for a "turbo"
                        # checkpoint; confirm recommended values on the model card.
                        steps = gr.Slider(label="Sampling steps", value=35, minimum=1, maximum=100, step=1)
                        cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=7, minimum=1, maximum=20, step=1)
                        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
        # --- Action Button ---
        with gr.Row():
            text_button = gr.Button("Run", variant='primary', elem_id="gen-button")
        # --- Output Components ---
        with gr.Row():
            image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
        with gr.Row():
            # Shows either a status/error message or the download link.
            download_link_display = gr.HTML(elem_id="download-link-container")

    # --- Event Listener ---
    # Inputs are passed positionally, so their order must match query's
    # parameters: (prompt, negative_prompt, steps, cfg_scale, seed, width, height).
    text_button.click(
        query,
        inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
        outputs=[image_output, download_link_display],
    )
# --- Launch the Gradio app ---
print("Starting Gradio app...")
# Whitelist the temp-image directory (absolute path computed at start-up) via
# allowed_paths; the generated download links point at files saved there.
launch_options = {
    "show_api": False,
    "share": False,  # Set to True only if you need a public link for direct testing
    "allowed_paths": [absolute_image_dir],
}
app.launch(**launch_options)