greendra commited on
Commit
6a3528b
·
verified ·
1 Parent(s): 2553412

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -60
app.py CHANGED
@@ -10,6 +10,7 @@ import json
10
  import uuid
11
  from urllib.parse import quote
12
  import traceback
 
13
 
14
  # Project by Nymbo
15
 
@@ -38,29 +39,86 @@ except OSError as e:
38
  absolute_image_dir = os.path.abspath(IMAGE_DIR)
39
  print(f"Absolute path for allowed_paths: {absolute_image_dir}")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Function to query the API and return the generated image and download link
42
  def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
43
- # Removed sampler and strength as they are not explicitly used in the payload below
44
- # Note: If the API endpoint *does* support sampler/strength, add them back to the payload
45
  if not prompt or not prompt.strip():
46
- print("Empty prompt received.")
47
- # Return None for image and an informative message for the HTML component
48
  return None, "<p style='color: orange; text-align: center;'>Please enter a prompt.</p>"
49
 
50
- key = random.randint(0, 999)
51
- print(f"\n--- Generation {key} Started ---")
52
 
53
  # Translation
54
  try:
55
  translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
56
  except Exception as e:
57
- translated_prompt = prompt # Fallback to original if translation fails
 
 
 
58
 
59
- # Add suffix to prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
61
- print(f'Generation {key} prompt: {final_prompt}')
62
 
63
- # Prepare the payload for the API call
64
  payload = {
65
  "inputs": final_prompt,
66
  "parameters": {
@@ -70,13 +128,16 @@ def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024,
70
  "seed": seed if seed != -1 else random.randint(1, 1000000000),
71
  "width": width,
72
  "height": height,
 
73
  }
74
  }
75
 
76
  # API Call Section
77
  try:
78
  if not headers:
79
- print("WARNING: Authorization header is missing (HF_READ_TOKEN not set?)")
 
 
80
  return None, "<p style='color: red; text-align: center;'>Configuration Error: API Token missing.</p>"
81
 
82
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
@@ -84,13 +145,19 @@ def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024,
84
 
85
  image_bytes = response.content
86
  if not image_bytes or len(image_bytes) < 100:
87
- print(f"Error: Received empty or very small response content (length: {len(image_bytes)}). Potential API issue.")
 
 
 
88
  return None, "<p style='color: red; text-align: center;'>API returned invalid image data.</p>"
89
 
90
  try:
91
  image = Image.open(io.BytesIO(image_bytes))
92
  except UnidentifiedImageError as img_err:
93
- print(f"Error: Could not identify or open image from API response bytes: {img_err}")
 
 
 
94
  return None, "<p style='color: red; text-align: center;'>Failed to process image data from API.</p>"
95
 
96
  # --- Save image and create download link ---
@@ -101,86 +168,90 @@ def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024,
101
  try:
102
  image.save(save_path, "PNG")
103
 
104
- if os.path.exists(save_path):
105
- file_size = os.path.getsize(save_path)
106
- if file_size < 100:
107
- print(f"WARNING: Saved file {save_path} is very small ({file_size} bytes). May indicate an issue.")
108
- else:
109
- print(f"CRITICAL ERROR: File NOT found at {save_path} (Absolute: {absolute_save_path}) immediately after saving!")
110
- return image, "<p style='color: red; text-align: center;'>Internal Error: Failed to confirm image file save.</p>"
111
-
112
- # Determine space name (adjust logic if API_URL format differs)
113
- try:
114
- space_name = "greendra-flux-1-schnell-serverless"
115
- # A more robust way might involve getting the space ID from env vars if available
116
- except IndexError:
117
- print("WARNING: Could not reliably determine space name from API_URL. Using a default.")
118
- space_name = "unknown-flux-space" # Provide a fallback
119
-
120
- relative_file_url = f"/gradio_api/file={save_path}"
121
 
 
 
 
122
  encoded_file_url = quote(relative_file_url)
123
  arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
124
- print(f"{arinteli_url}")
125
 
126
- # Use Gradio's primary button style for the link
127
  download_html = (
128
- f'<div style="text-align: center; margin-top: 15px;">' # Added margin-top
129
  f'<a href="{arinteli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
130
  f'Download Image'
131
  f'</a>'
132
  f'</div>'
133
  )
134
 
135
- print(f"--- Generation {key} Done ---")
 
 
 
136
  return image, download_html
137
 
138
  except (OSError, IOError) as save_err:
139
- print(f"CRITICAL ERROR: Failed to save image to {save_path} (Absolute: {absolute_save_path}): {save_err}")
 
 
 
 
140
  traceback.print_exc()
141
  return image, f"<p style='color: red; text-align: center;'>Internal Error: Failed to save image file. Details: {save_err}</p>"
142
  except Exception as e:
143
- print(f"Error during link creation or unexpected save issue: {e}")
 
 
 
144
  traceback.print_exc()
145
  return image, "<p style='color: red; text-align: center;'>Internal Error creating download link.</p>"
146
 
147
- # --- Exception Handling for API Call ---
148
  except requests.exceptions.Timeout:
149
- print(f"Error: Request timed out after {timeout} seconds.")
 
 
150
  return None, "<p style='color: red; text-align: center;'>Request timed out. The model is taking too long.</p>"
151
  except requests.exceptions.HTTPError as e:
152
  status_code = e.response.status_code
153
  error_text = e.response.text
 
154
  try:
155
  error_data = e.response.json()
156
- error_text = error_data.get('error', error_text)
157
- if isinstance(error_text, dict) and 'message' in error_text:
158
- error_text = error_text['message']
 
159
  except json.JSONDecodeError:
160
  pass
161
 
162
- print(f"Error: Failed API call. Status: {status_code}, Response: {error_text}")
 
 
 
163
 
 
164
  if status_code == 503:
165
- estimated_time = error_data.get("estimated_time") if 'error_data' in locals() and isinstance(error_data, dict) else None
166
- if estimated_time:
167
- error_message = f"Model is loading (503), please wait. Est. time: {estimated_time:.1f}s. Try again."
168
- else:
169
- error_message = f"Service unavailable (503). Model might be loading or down. Try again later."
170
- elif status_code == 400:
171
- error_message = f"Bad Request (400): Check parameters. API Error: {error_text}"
172
- elif status_code == 422:
173
- error_message = f"Validation Error (422): Input invalid. API Error: {error_text}"
174
- elif status_code == 401 or status_code == 403:
175
- error_message = f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
176
- else:
177
- error_message = f"API Error: {status_code}. Details: {error_text}"
178
 
179
  return None, f"<p style='color: red; text-align: center;'>{error_message}</p>"
180
  except Exception as e:
181
- print(f"An unexpected error occurred: {e}")
 
 
 
182
  traceback.print_exc()
183
- return None, f"<p style='color: red; text-align: center;'>An unexpected error occurred: {e}</p>"
184
 
185
 
186
  # CSS to style the app
@@ -215,9 +286,8 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
215
  with gr.Row():
216
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
217
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
218
- steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1) # Default updated based on query function default
219
- cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1) # Label updated
220
- # Removed strength and sampler sliders as they are not passed to query
221
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
222
 
223
  with gr.Row():
 
10
  import uuid
11
  from urllib.parse import quote
12
  import traceback
13
+ from openai import OpenAI, RateLimitError, APIConnectionError, AuthenticationError
14
 
15
  # Project by Nymbo
16
 
 
39
  absolute_image_dir = os.path.abspath(IMAGE_DIR)
40
  print(f"Absolute path for allowed_paths: {absolute_image_dir}")
41
 
42
+ # --- OpenAI Client ---
43
+ try:
44
+ openai_api_key = os.getenv("OPENAI_API_KEY")
45
+ if not openai_api_key:
46
+ print("WARNING: OPENAI_API_KEY environment variable not set or empty. Moderation will be skipped.")
47
+ openai_client = None
48
+ else:
49
+ openai_client = OpenAI()
50
+ print("OpenAI client initialized successfully.")
51
+ except Exception as e:
52
+ print(f"ERROR: Failed to initialize OpenAI client: {e}. Moderation will be skipped.")
53
+ openai_client = None
54
+
55
  # Function to query the API and return the generated image and download link
56
  def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
57
+ # Basic Input Validation
 
58
  if not prompt or not prompt.strip():
59
+ print("WARNING: Empty prompt received.\n") # Add newline for separation
 
60
  return None, "<p style='color: orange; text-align: center;'>Please enter a prompt.</p>"
61
 
62
+ # Store original prompt for logging
63
+ original_prompt = prompt
64
 
65
  # Translation
66
  try:
67
  translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
68
  except Exception as e:
69
+ print(f"WARNING: Translation failed. Using original prompt.")
70
+ print(f" Error: {e}")
71
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
72
+ translated_prompt = prompt
73
 
74
+ # --- OpenAI Moderation Check ---
75
+ if openai_client:
76
+ try:
77
+ mod_response = openai_client.moderations.create(
78
+ model="omni-moderation-latest",
79
+ input=translated_prompt
80
+ )
81
+ result = mod_response.results[0]
82
+
83
+ if result.categories.sexual_minors:
84
+ print("BLOCKED:")
85
+ print(f" Reason: sexual/minors")
86
+ print(f" Prompt: '{original_prompt}'")
87
+ print(f" Translated: '{translated_prompt}'\n") # Add newline
88
+ return None, "<p style='color: red; text-align: center;'>Prompt violates safety guidelines. Generation blocked.</p>"
89
+
90
+ except AuthenticationError:
91
+ print("BLOCKED:")
92
+ print(f" Reason: OpenAI Auth Error")
93
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
94
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Generation blocked.</p>"
95
+ except RateLimitError:
96
+ print("BLOCKED:")
97
+ print(f" Reason: OpenAI Rate Limit")
98
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
99
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
100
+ except APIConnectionError as e:
101
+ print("BLOCKED:")
102
+ print(f" Reason: OpenAI Connection Error")
103
+ print(f" Prompt: '{original_prompt}'")
104
+ print(f" Error: {e}\n") # Add newline
105
+ return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
106
+ except Exception as e:
107
+ print("BLOCKED:")
108
+ print(f" Reason: OpenAI Unexpected Error")
109
+ print(f" Prompt: '{original_prompt}'")
110
+ print(f" Error: {e}\n") # Add newline
111
+ traceback.print_exc()
112
+ return None, "<p style='color: red; text-align: center;'>An unexpected error occurred during safety check. Generation blocked.</p>"
113
+ else:
114
+ print(f"WARNING: OpenAI client not available. Skipping moderation.")
115
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
116
+
117
+ # --- Proceed with Generation ---
118
+ # Add suffix (Ensure consistency with gradio1.py if desired, or keep model-specific suffixes)
119
  final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
 
120
 
121
+ # Prepare payload (Adjust parameters based on FLUX.1-schnell API requirements if different)
122
  payload = {
123
  "inputs": final_prompt,
124
  "parameters": {
 
128
  "seed": seed if seed != -1 else random.randint(1, 1000000000),
129
  "width": width,
130
  "height": height,
131
+ # Add any other specific parameters for FLUX.1-schnell if needed
132
  }
133
  }
134
 
135
  # API Call Section
136
  try:
137
  if not headers:
138
+ print("FAILED:")
139
+ print(f" Reason: HF Token Missing")
140
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
141
  return None, "<p style='color: red; text-align: center;'>Configuration Error: API Token missing.</p>"
142
 
143
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
 
145
 
146
  image_bytes = response.content
147
  if not image_bytes or len(image_bytes) < 100:
148
+ print("FAILED:")
149
+ print(f" Reason: Invalid Image Data (Empty/Small)")
150
+ print(f" Prompt: '{original_prompt}'")
151
+ print(f" Length: {len(image_bytes)}\n") # Add newline
152
  return None, "<p style='color: red; text-align: center;'>API returned invalid image data.</p>"
153
 
154
  try:
155
  image = Image.open(io.BytesIO(image_bytes))
156
  except UnidentifiedImageError as img_err:
157
+ print("FAILED:")
158
+ print(f" Reason: Image Processing Error")
159
+ print(f" Prompt: '{original_prompt}'")
160
+ print(f" Error: {img_err}\n") # Add newline
161
  return None, "<p style='color: red; text-align: center;'>Failed to process image data from API.</p>"
162
 
163
  # --- Save image and create download link ---
 
168
  try:
169
  image.save(save_path, "PNG")
170
 
171
+ if not os.path.exists(save_path) or os.path.getsize(save_path) < 100:
172
+ print("FAILED:")
173
+ print(f" Reason: Image Save Verification Error")
174
+ print(f" Prompt: '{original_prompt}'")
175
+ print(f" Path: '{save_path}'\n") # Add newline
176
+ return image, "<p style='color: red; text-align: center;'>Internal Error: Failed to confirm image file save.</p>"
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ # Use the correct space name for this Gradio app
179
+ space_name = "greendra-flux-1-schnell-serverless" # Updated space_name
180
+ relative_file_url = f"gradio_api/file={save_path}"
181
  encoded_file_url = quote(relative_file_url)
182
  arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
 
183
 
 
184
  download_html = (
185
+ f'<div style="text-align: center; margin-top: 15px;">' # Keep existing style
186
  f'<a href="{arinteli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
187
  f'Download Image'
188
  f'</a>'
189
  f'</div>'
190
  )
191
 
192
+ # *** SUCCESS LOG *** (Updated format)
193
+ print("SUCCESS:")
194
+ print(f" Prompt: '{original_prompt}'")
195
+ print(f" URL: '{arinteli_url}'\n") # Add newline
196
  return image, download_html
197
 
198
  except (OSError, IOError) as save_err:
199
+ print("FAILED:")
200
+ print(f" Reason: Image Save IO Error")
201
+ print(f" Prompt: '{original_prompt}'")
202
+ print(f" Path: '{save_path}'")
203
+ print(f" Error: {save_err}\n") # Add newline
204
  traceback.print_exc()
205
  return image, f"<p style='color: red; text-align: center;'>Internal Error: Failed to save image file. Details: {save_err}</p>"
206
  except Exception as e:
207
+ print("FAILED:")
208
+ print(f" Reason: Link Creation/Save Unexpected Error")
209
+ print(f" Prompt: '{original_prompt}'")
210
+ print(f" Error: {e}\n") # Add newline
211
  traceback.print_exc()
212
  return image, "<p style='color: red; text-align: center;'>Internal Error creating download link.</p>"
213
 
214
+ # --- Exception Handling for API Call --- (Updated format)
215
  except requests.exceptions.Timeout:
216
+ print("FAILED:")
217
+ print(f" Reason: HF API Timeout")
218
+ print(f" Prompt: '{original_prompt}'\n") # Add newline
219
  return None, "<p style='color: red; text-align: center;'>Request timed out. The model is taking too long.</p>"
220
  except requests.exceptions.HTTPError as e:
221
  status_code = e.response.status_code
222
  error_text = e.response.text
223
+ error_data = {}
224
  try:
225
  error_data = e.response.json()
226
+ parsed_error = error_data.get('error', error_text)
227
+ if isinstance(parsed_error, dict) and 'message' in parsed_error: error_text = parsed_error['message']
228
+ elif isinstance(parsed_error, list): error_text = "; ".join(map(str, parsed_error))
229
+ else: error_text = str(parsed_error)
230
  except json.JSONDecodeError:
231
  pass
232
 
233
+ print("FAILED:")
234
+ print(f" Reason: HF API HTTP Error {status_code}")
235
+ print(f" Prompt: '{original_prompt}'")
236
+ print(f" Details: '{error_text[:200]}'\n") # Add newline
237
 
238
+ # User-facing messages (Keep consistent with previous version of this file)
239
  if status_code == 503:
240
+ estimated_time = error_data.get("estimated_time") if isinstance(error_data, dict) else None
241
+ error_message = f"Model is loading (503), please wait." + (f" Est. time: {estimated_time:.1f}s." if estimated_time else "") + " Try again."
242
+ elif status_code == 400: error_message = f"Bad Request (400): Check parameters. API Error: {error_text}"
243
+ elif status_code == 422: error_message = f"Validation Error (422): Input invalid. API Error: {error_text}"
244
+ elif status_code == 401 or status_code == 403: error_message = f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
245
+ else: error_message = f"API Error: {status_code}. Details: {error_text}" # Adjusted generic message slightly
 
 
 
 
 
 
 
246
 
247
  return None, f"<p style='color: red; text-align: center;'>{error_message}</p>"
248
  except Exception as e:
249
+ print("FAILED:")
250
+ print(f" Reason: Unexpected Error During Generation")
251
+ print(f" Prompt: '{original_prompt}'")
252
+ print(f" Error: {e}\n") # Add newline
253
  traceback.print_exc()
254
+ return None, f"<p style='color: red; text-align: center;'>An unexpected error occurred: {e}</p>" # Keep specific error for user
255
 
256
 
257
  # CSS to style the app
 
286
  with gr.Row():
287
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
288
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
289
+ steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
290
+ cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
 
291
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")
292
 
293
  with gr.Row():