greendra commited on
Commit
1b6cf25
·
verified ·
1 Parent(s): 7ab7d38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -104
app.py CHANGED
@@ -10,6 +10,7 @@ import json
10
  import uuid
11
  from urllib.parse import quote
12
  import traceback # For detailed error logging
 
13
 
14
  # Project by Nymbo
15
 
@@ -40,106 +41,143 @@ except OSError as e:
40
  absolute_image_dir = os.path.abspath(IMAGE_DIR)
41
  print(f"Absolute path for allowed_paths: {absolute_image_dir}")
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # --- Function to query the API and return the generated image and download link ---
45
  def query(prompt, negative_prompt, steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
46
- # Renamed `strength` input as it wasn't used in the payload for txt2img
47
- # Removed `sampler` input as it wasn't used in payload
48
-
49
  # Basic Input Validation
50
  if not prompt or not prompt.strip():
51
- print("Empty prompt received.")
52
- # Return None for image and an informative message for the HTML component
53
  return None, "<p style='color: orange; text-align: center;'>Please enter a prompt.</p>"
54
 
55
- key = random.randint(0, 999)
56
- print(f"\n--- Generation {key} Started ---")
57
 
58
  # Translation
59
  try:
60
- # Using 'auto' source detection
61
  translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
62
  except Exception as e:
63
- print(f"Translation failed: {e}. Using original prompt.")
64
- translated_prompt = prompt # Fallback to original if translation fails
 
 
65
 
66
- # Add suffix to prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
68
- print(f'Generation {key} prompt: {final_prompt}')
69
 
70
- # Prepare payload for API call
71
  payload = {
72
  "inputs": final_prompt,
73
- "parameters": { # Nested parameters as per original structure
74
- "width": width,
75
- "height": height,
76
- "num_inference_steps": steps,
77
- "negative_prompt": negative_prompt,
78
- "guidance_scale": cfg_scale,
79
  "seed": seed if seed != -1 else random.randint(1, 1000000000),
80
  }
81
- # Add other parameters here if needed (e.g., sampler if supported)
82
  }
83
 
84
  # API Call Section
85
  try:
86
  if not headers:
87
- print("WARNING: Authorization header is missing (HF_READ_TOKEN not set?)")
88
- # Handle error appropriately - maybe return an error message
 
89
  return None, "<p style='color: red; text-align: center;'>Configuration Error: API Token missing.</p>"
90
 
91
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
92
- response.raise_for_status() # Raises HTTPError for 4xx/5xx responses
93
 
94
  image_bytes = response.content
95
- # Check for valid image data before proceeding
96
- if not image_bytes or len(image_bytes) < 100: # Basic check for empty/tiny response
97
- print(f"Error: Received empty or very small response content (length: {len(image_bytes)}). Potential API issue.")
 
 
98
  return None, "<p style='color: red; text-align: center;'>API returned invalid image data.</p>"
99
 
100
  try:
101
  image = Image.open(io.BytesIO(image_bytes))
102
  except UnidentifiedImageError as img_err:
103
- print(f"Error: Could not identify or open image from API response bytes: {img_err}")
104
- # Optionally save the invalid bytes for debugging
105
- # error_bytes_path = os.path.join(IMAGE_DIR, f"error_{key}_bytes.bin")
106
- # with open(error_bytes_path, "wb") as f: f.write(image_bytes)
107
- # print(f"Saved problematic bytes to {error_bytes_path}")
108
  return None, "<p style='color: red; text-align: center;'>Failed to process image data from API.</p>"
109
 
110
  # --- Save image and create download link ---
111
  filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
112
- # save_path is relative to the script's execution directory
113
  save_path = os.path.join(IMAGE_DIR, filename)
114
- absolute_save_path = os.path.abspath(save_path) # Get absolute path for logging
115
 
116
  try:
117
- # Save image explicitly as PNG
118
  image.save(save_path, "PNG")
119
 
120
- # *** Verify file exists after saving ***
121
- if os.path.exists(save_path):
122
- file_size = os.path.getsize(save_path)
123
- if file_size < 100: # Warn if the saved file is suspiciously small
124
- print(f"WARNING: Saved file {save_path} is very small ({file_size} bytes). May indicate an issue.")
125
- # Optionally return a warning message in the UI
126
- # return image, "<p style='color: orange; text-align: center;'>Warning: Saved image file is unexpectedly small.</p>"
127
- else:
128
- # This indicates a serious problem if save() didn't raise an error but the file isn't there
129
- print(f"CRITICAL ERROR: File NOT found at {save_path} (Absolute: {absolute_save_path}) immediately after saving!")
130
- return image, "<p style='color: red; text-align: center;'>Internal Error: Failed to confirm image file save.</p>"
131
-
132
- # Get current space name from the API URL
133
- space_name = "greendra-stable-diffusion-3-5-large-serverless"
134
-
135
- relative_file_url = f"/gradio_api/file={save_path}"
136
 
 
 
137
  encoded_file_url = quote(relative_file_url)
138
- # Add space_name parameter to the URL
139
  arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
140
- print(f"{arinteli_url}")
141
 
142
- # Use simpler button style like the Run button
143
  download_html = (
144
  f'<div style="text-align: center;">'
145
  f'<a href="{arinteli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
@@ -148,63 +186,70 @@ def query(prompt, negative_prompt, steps=4, cfg_scale=0, seed=-1, width=1024, he
148
  f'</div>'
149
  )
150
 
151
- print(f"--- Generation {key} Done ---")
 
 
 
152
  return image, download_html
153
 
154
  except (OSError, IOError) as save_err:
155
- # Handle errors during the file save operation
156
- print(f"CRITICAL ERROR: Failed to save image to {save_path} (Absolute: {absolute_save_path}): {save_err}")
157
- traceback.print_exc() # Log detailed traceback
158
- return image, f"<p style='color: red; text-align: center;'>Internal Error: Failed to save image file. Details: {save_err}</p>"
 
 
 
159
  except Exception as e:
160
- # Catch any other unexpected errors during link creation/saving
161
- print(f"Error during link creation or unexpected save issue: {e}")
 
 
162
  traceback.print_exc()
163
- # Return the generated image (if available) but indicate link error
164
  return image, "<p style='color: red; text-align: center;'>Internal Error creating download link.</p>"
165
 
166
  # --- Exception Handling for API Call ---
167
  except requests.exceptions.Timeout:
168
- print(f"Error: Request timed out after {timeout} seconds.")
 
 
169
  return None, "<p style='color: red; text-align: center;'>Request timed out. The model is taking too long.</p>"
170
  except requests.exceptions.HTTPError as e:
171
- # Handle HTTP errors from the API (4xx, 5xx)
172
  status_code = e.response.status_code
173
- error_text = e.response.text # Default error text
 
174
  try:
175
- # Try to parse more specific error message from JSON response
176
  error_data = e.response.json()
177
- error_text = error_data.get('error', error_text)
178
- if isinstance(error_text, dict) and 'message' in error_text:
179
- error_text = error_text['message'] # Handle nested messages
 
180
  except json.JSONDecodeError:
181
- pass # Keep raw text if not JSON
182
-
183
- print(f"Error: Failed API call. Status: {status_code}, Response: {error_text}")
184
-
185
- # Generate user-friendly messages based on status code
186
- if status_code == 503: # Service Unavailable (often model loading)
187
- estimated_time = error_data.get("estimated_time") if 'error_data' in locals() and isinstance(error_data, dict) else None
188
- if estimated_time:
189
- error_message = f"Model is loading (503), please wait. Est. time: {estimated_time:.1f}s. Try again."
190
- else:
191
- error_message = f"Service unavailable (503). Model might be loading or down. Try again later."
192
- elif status_code == 400: # Bad Request (invalid parameters)
193
- error_message = f"Bad Request (400): Check parameters. API Error: {error_text}"
194
- elif status_code == 422: # Unprocessable Entity (validation error)
195
- error_message = f"Validation Error (422): Input invalid. API Error: {error_text}"
196
- elif status_code == 401 or status_code == 403: # Unauthorized / Forbidden
197
- error_message = f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
198
- else: # Generic API error
199
- error_message = f"API Error: {status_code}. Details: {error_text}"
200
-
201
- # Return None for image, and the error message string for the HTML component
202
  return None, f"<p style='color: red; text-align: center;'>{error_message}</p>"
203
  except Exception as e:
204
- # Catch any other unexpected errors during the process
205
- print(f"An unexpected error occurred: {e}")
206
- traceback.print_exc() # Log detailed traceback
207
- return None, f"<p style='color: red; text-align: center;'>An unexpected error occurred: {e}</p>"
 
 
208
 
209
 
210
  # --- CSS Styling ---
@@ -252,11 +297,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
252
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
253
  steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
254
  cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
255
- # Removed 'strength' slider as it wasn't used in query payload
256
- # strength = gr.Slider(label="Strength (Primarily for Img2Img)", value=0.7, minimum=0, maximum=1, step=0.001, info="Note: Strength is mainly used in Image-to-Image generation.")
257
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
258
- # Removed 'method' radio as it wasn't used in query payload
259
- # method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"], info="Note: Sampler choice might not be supported by this API.")
260
 
261
  # --- Action Button ---
262
  with gr.Row():
@@ -266,24 +307,20 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
266
  with gr.Row():
267
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
268
  with gr.Row():
269
- # HTML component to display status messages or the download link
270
  download_link_display = gr.HTML(elem_id="download-link-container")
271
 
272
  # --- Event Listener ---
273
- # Bind the button click to the query function
274
  text_button.click(
275
  query,
276
- # Ensure the inputs list matches the parameters of the `query` function definition
277
  inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
278
- # Outputs go to the image component and the HTML component
279
  outputs=[image_output, download_link_display]
280
  )
281
 
282
  # --- Launch the Gradio app ---
283
  print("Starting Gradio app...")
284
- # Use allowed_paths with the pre-calculated absolute path to the image directory
285
  app.launch(
286
  show_api=False,
287
- share=False, # Set to True only if you need a public link for direct testing
288
- allowed_paths=[absolute_image_dir]
 
289
  )
 
10
  import uuid
11
  from urllib.parse import quote
12
  import traceback # For detailed error logging
13
+ from openai import OpenAI, RateLimitError, APIConnectionError, AuthenticationError # Added OpenAI errors
14
 
15
  # Project by Nymbo
16
 
 
41
  absolute_image_dir = os.path.abspath(IMAGE_DIR)
42
  print(f"Absolute path for allowed_paths: {absolute_image_dir}")
43
 
44
+ # --- OpenAI Client ---
45
+ # Initialize OpenAI client (it will automatically look for the OPENAI_API_KEY env var)
46
+ try:
47
+ # Check if the environment variable exists and is not empty
48
+ openai_api_key = os.getenv("OPENAI_API_KEY")
49
+ if not openai_api_key:
50
+ print("WARNING: OPENAI_API_KEY environment variable not set or empty. Moderation will be skipped.")
51
+ openai_client = None
52
+ else:
53
+ openai_client = OpenAI() # Key is implicitly used by the library
54
+ print("OpenAI client initialized successfully.")
55
+ except Exception as e:
56
+ print(f"ERROR: Failed to initialize OpenAI client: {e}. Moderation will be skipped.")
57
+ openai_client = None # Set to None so we can check later
58
 
59
  # --- Function to query the API and return the generated image and download link ---
60
  def query(prompt, negative_prompt, steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
 
 
 
61
  # Basic Input Validation
62
  if not prompt or not prompt.strip():
63
+ print("WARNING: Empty prompt received.\n") # Add newline for separation
 
64
  return None, "<p style='color: orange; text-align: center;'>Please enter a prompt.</p>"
65
 
66
+ # Store original prompt for logging
67
+ original_prompt = prompt
68
 
69
  # Translation
70
  try:
 
71
  translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
72
  except Exception as e:
73
+ print(f"WARNING: Translation failed. Using original prompt.")
74
+ print(f" Error: {e}")
75
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
76
+ translated_prompt = prompt
77
 
78
+ # --- OpenAI Moderation Check ---
79
+ if openai_client:
80
+ try:
81
+ mod_response = openai_client.moderations.create(
82
+ model="omni-moderation-latest",
83
+ input=translated_prompt
84
+ )
85
+ result = mod_response.results[0]
86
+
87
+ if result.categories.sexual_minors:
88
+ print("BLOCKED:")
89
+ print(f" Reason: sexual/minors")
90
+ print(f" Prompt: '{original_prompt}'")
91
+ print(f" Translated: '{translated_prompt}'\n") # Add newline
92
+ return None, "<p style='color: red; text-align: center;'>Prompt violates safety guidelines. Generation blocked.</p>"
93
+
94
+ except AuthenticationError:
95
+ print("BLOCKED:")
96
+ print(f" Reason: OpenAI Auth Error")
97
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
98
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Generation blocked.</p>"
99
+ except RateLimitError:
100
+ print("BLOCKED:")
101
+ print(f" Reason: OpenAI Rate Limit")
102
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
103
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
104
+ except APIConnectionError as e:
105
+ print("BLOCKED:")
106
+ print(f" Reason: OpenAI Connection Error")
107
+ print(f" Prompt: '{original_prompt}'")
108
+ print(f" Error: {e}\n") # Add newline
109
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
110
+ except Exception as e:
111
+ print("BLOCKED:")
112
+ print(f" Reason: OpenAI Unexpected Error")
113
+ print(f" Prompt: '{original_prompt}'")
114
+ print(f" Error: {e}\n") # Add newline
115
+ traceback.print_exc() # Keep traceback for unexpected errors
116
+ return None, "<p style='color: red; text-align: center;'>An unexpected error occurred during safety check. Generation blocked.</p>"
117
+ else:
118
+ print(f"WARNING: OpenAI client not available. Skipping moderation.")
119
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
120
+
121
+ # --- Proceed with Generation ---
122
  final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
 
123
 
 
124
  payload = {
125
  "inputs": final_prompt,
126
+ "parameters": {
127
+ "width": width, "height": height, "num_inference_steps": steps,
128
+ "negative_prompt": negative_prompt, "guidance_scale": cfg_scale,
 
 
 
129
  "seed": seed if seed != -1 else random.randint(1, 1000000000),
130
  }
 
131
  }
132
 
133
  # API Call Section
134
  try:
135
  if not headers:
136
+ print("FAILED:")
137
+ print(f" Reason: HF Token Missing")
138
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
139
  return None, "<p style='color: red; text-align: center;'>Configuration Error: API Token missing.</p>"
140
 
141
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
142
+ response.raise_for_status()
143
 
144
  image_bytes = response.content
145
+ if not image_bytes or len(image_bytes) < 100:
146
+ print("FAILED:")
147
+ print(f" Reason: Invalid Image Data (Empty/Small)")
148
+ print(f" Prompt: '{original_prompt}'")
149
+ print(f" Length: {len(image_bytes)}\n") # Add newline
150
  return None, "<p style='color: red; text-align: center;'>API returned invalid image data.</p>"
151
 
152
  try:
153
  image = Image.open(io.BytesIO(image_bytes))
154
  except UnidentifiedImageError as img_err:
155
+ print("FAILED:")
156
+ print(f" Reason: Image Processing Error")
157
+ print(f" Prompt: '{original_prompt}'")
158
+ print(f" Error: {img_err}\n") # Add newline
 
159
  return None, "<p style='color: red; text-align: center;'>Failed to process image data from API.</p>"
160
 
161
  # --- Save image and create download link ---
162
  filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
 
163
  save_path = os.path.join(IMAGE_DIR, filename)
164
+ absolute_save_path = os.path.abspath(save_path)
165
 
166
  try:
 
167
  image.save(save_path, "PNG")
168
 
169
+ if not os.path.exists(save_path) or os.path.getsize(save_path) < 100:
170
+ print("FAILED:")
171
+ print(f" Reason: Image Save Verification Error")
172
+ print(f" Prompt: '{original_prompt}'")
173
+ print(f" Path: '{save_path}'\n") # Add newline
174
+ return image, "<p style='color: red; text-align: center;'>Internal Error: Failed to confirm image file save.</p>"
 
 
 
 
 
 
 
 
 
 
175
 
176
+ space_name = "greendra-stable-diffusion-3-5-large-serverless"
177
+ relative_file_url = f"file={save_path}"
178
  encoded_file_url = quote(relative_file_url)
 
179
  arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
 
180
 
 
181
  download_html = (
182
  f'<div style="text-align: center;">'
183
  f'<a href="{arinteli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
 
186
  f'</div>'
187
  )
188
 
189
+ # *** SUCCESS LOG ***
190
+ print("SUCCESS:")
191
+ print(f" Prompt: '{original_prompt}'")
192
+ print(f" {arinteli_url}\n") # Add newline
193
  return image, download_html
194
 
195
  except (OSError, IOError) as save_err:
196
+ print("FAILED:")
197
+ print(f" Reason: Image Save IO Error")
198
+ print(f" Prompt: '{original_prompt}'")
199
+ print(f" Path: '{save_path}'")
200
+ print(f" Error: {save_err}\n") # Add newline
201
+ traceback.print_exc()
202
+ return image, f"<p style='color: red; text-align: center;'>Internal Error: Failed to save image file.</p>"
203
  except Exception as e:
204
+ print("FAILED:")
205
+ print(f" Reason: Link Creation/Save Unexpected Error")
206
+ print(f" Prompt: '{original_prompt}'")
207
+ print(f" Error: {e}\n") # Add newline
208
  traceback.print_exc()
 
209
  return image, "<p style='color: red; text-align: center;'>Internal Error creating download link.</p>"
210
 
211
  # --- Exception Handling for API Call ---
212
  except requests.exceptions.Timeout:
213
+ print("FAILED:")
214
+ print(f" Reason: HF API Timeout")
215
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
216
  return None, "<p style='color: red; text-align: center;'>Request timed out. The model is taking too long.</p>"
217
  except requests.exceptions.HTTPError as e:
 
218
  status_code = e.response.status_code
219
+ error_text = e.response.text
220
+ error_data = {}
221
  try:
 
222
  error_data = e.response.json()
223
+ parsed_error = error_data.get('error', error_text)
224
+ if isinstance(parsed_error, dict) and 'message' in parsed_error: error_text = parsed_error['message']
225
+ elif isinstance(parsed_error, list): error_text = "; ".join(map(str, parsed_error))
226
+ else: error_text = str(parsed_error)
227
  except json.JSONDecodeError:
228
+ pass
229
+
230
+ print("FAILED:")
231
+ print(f" Reason: HF API HTTP Error {status_code}")
232
+ print(f" Prompt: '{original_prompt}'")
233
+ print(f" Details: '{error_text[:200]}'\n") # Add newline
234
+
235
+ # User-facing messages remain the same
236
+ if status_code == 503:
237
+ estimated_time = error_data.get("estimated_time") if isinstance(error_data, dict) else None
238
+ error_message = f"Model is loading (503), please wait." + (f" Est. time: {estimated_time:.1f}s." if estimated_time else "") + " Try again."
239
+ elif status_code == 400: error_message = f"Bad Request (400): Check parameters."
240
+ elif status_code == 422: error_message = f"Validation Error (422): Input invalid."
241
+ elif status_code == 401 or status_code == 403: error_message = f"Authorization Error ({status_code}): Check API Token."
242
+ elif status_code == 429: error_message = f"Rate Limit Error (429): Too many requests. Try again later."
243
+ else: error_message = f"API Error ({status_code}). Please try again."
244
+
 
 
 
 
245
  return None, f"<p style='color: red; text-align: center;'>{error_message}</p>"
246
  except Exception as e:
247
+ print("FAILED:")
248
+ print(f" Reason: Unexpected Error During Generation")
249
+ print(f" Prompt: '{original_prompt}'")
250
+ print(f" Error: {e}\n") # Add newline
251
+ traceback.print_exc()
252
+ return None, f"<p style='color: red; text-align: center;'>An unexpected error occurred. Please check logs.</p>"
253
 
254
 
255
  # --- CSS Styling ---
 
297
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
298
  steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
299
  cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
 
 
300
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
 
 
301
 
302
  # --- Action Button ---
303
  with gr.Row():
 
307
  with gr.Row():
308
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
309
  with gr.Row():
 
310
  download_link_display = gr.HTML(elem_id="download-link-container")
311
 
312
  # --- Event Listener ---
 
313
  text_button.click(
314
  query,
 
315
  inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
 
316
  outputs=[image_output, download_link_display]
317
  )
318
 
319
  # --- Launch the Gradio app ---
320
  print("Starting Gradio app...")
 
321
  app.launch(
322
  show_api=False,
323
+ share=False,
324
+ allowed_paths=[absolute_image_dir],
325
+ # server_name="0.0.0.0"
326
  )