chengzeyi committed on
Commit
4f62561
Β·
verified Β·
1 Parent(s): 7ced770

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +198 -92
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import requests
3
  import time
 
4
  import threading
5
  import uuid
6
  import base64
@@ -18,78 +19,142 @@ API_KEY = os.getenv("WAVESPEED_API_KEY")
18
  if not API_KEY:
19
  raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
20
 
 
21
  MODEL_URL = "TostAI/nsfw-text-detection-large"
22
- CLASS_NAMES = {0: "βœ… SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}
 
23
 
24
- try:
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
26
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
27
- except Exception as e:
28
- raise RuntimeError(f"Failed to load safety model: {str(e)}")
29
 
 
 
 
 
 
 
30
 
31
- class SessionManager:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  _instances = {}
33
  _lock = threading.Lock()
34
 
35
  @classmethod
36
- def get_session(cls, session_id):
 
 
 
37
  with cls._lock:
38
- if session_id not in cls._instances:
39
- cls._instances[session_id] = {
40
- 'count': 0,
41
- 'history': [],
42
- 'last_active': time.time()
43
- }
44
- return cls._instances[session_id]
45
 
46
  @classmethod
47
- def cleanup_sessions(cls):
 
48
  with cls._lock:
49
- now = time.time()
50
- expired = [
51
- k for k, v in cls._instances.items()
52
- if now - v['last_active'] > 3600
53
- ]
54
- for k in expired:
55
- del cls._instances[k]
 
56
 
57
 
58
- class RateLimiter:
59
 
60
  def __init__(self):
61
- self.clients = {}
62
  self.lock = threading.Lock()
 
 
63
 
64
- def check(self, client_id):
65
  with self.lock:
66
- now = time.time()
67
- if client_id not in self.clients:
68
- self.clients[client_id] = {'count': 1, 'reset': now + 3600}
69
- return True
70
- if now > self.clients[client_id]['reset']:
71
- self.clients[client_id] = {'count': 1, 'reset': now + 3600}
72
- return True
73
- if self.clients[client_id]['count'] >= 20:
74
- return False
75
- self.clients[client_id]['count'] += 1
76
- return True
77
 
 
 
 
78
 
79
- session_manager = SessionManager()
80
- rate_limiter = RateLimiter()
 
 
 
 
 
 
 
81
 
82
 
83
- def create_error_image(message):
84
- img = Image.new("RGB", (512, 512), "#ffdddd")
85
- try:
86
- font = ImageFont.truetype("arial.ttf", 24)
87
- except:
88
- font = ImageFont.load_default()
89
- draw = ImageDraw.Draw(img)
90
- text = f"Error: {message[:60]}..." if len(message) > 60 else message
91
- draw.text((50, 200), text, fill="#ff0000", font=font)
92
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
  @torch.no_grad()
@@ -112,28 +177,48 @@ def decode_base64_to_image(base64_str):
112
  return Image.open(io.BytesIO(image_data))
113
 
114
 
115
- def generate_image(image_file,
116
- prompt,
117
- seed,
118
- session_id,
119
- enable_safety_checker=True):
 
 
 
120
  try:
121
- if enable_safety_checker:
122
- safety_level = classify_prompt(prompt)
123
- if safety_level != 0:
124
- error_img = create_error_image(CLASS_NAMES[safety_level])
125
- yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, "", None
126
- return
 
 
 
 
 
127
 
128
- if not rate_limiter.check(session_id):
129
- error_img = create_error_image(
130
- "Hourly limit exceeded (20 requests)")
131
- yield "❌ Too many requests, please try again later", error_img, "", None
 
 
 
 
 
 
 
132
  return
133
 
134
- session = session_manager.get_session(session_id)
135
- session['last_active'] = time.time()
136
- session['count'] += 1
 
 
 
 
 
137
 
138
  error_messages = []
139
  if not image_file:
@@ -143,16 +228,27 @@ def generate_image(image_file,
143
  if not prompt.strip():
144
  error_messages.append("Prompt cannot be empty")
145
  if error_messages:
146
- error_img = create_error_image(" | ".join(error_messages))
147
- yield "❌ Input validation failed", error_img, "", None
 
148
  return
149
 
 
 
 
 
 
 
 
 
 
 
150
  try:
151
  base64_image = image_to_base64(image_file)
152
  input_image = decode_base64_to_image(base64_image)
153
  except Exception as e:
154
- error_img = create_error_image(f"File processing failed: {str(e)}")
155
- yield "❌ File processing failed", error_img, "", None
156
  return
157
 
158
  headers = {
@@ -178,7 +274,7 @@ def generate_image(image_file,
178
  start_time = time.time()
179
 
180
  for _ in range(60):
181
- time.sleep(1)
182
  resp = requests.get(result_url, headers=headers)
183
  resp.raise_for_status()
184
 
@@ -188,25 +284,28 @@ def generate_image(image_file,
188
  if status == "completed":
189
  elapsed = time.time() - start_time
190
  output_url = data["outputs"][0]
191
- session["history"].append(output_url)
192
- yield f"πŸŽ‰ Generation successful! Time taken {elapsed:.1f}s", output_url, output_url, update_recent_gallery(prompt, input_image, output_url)
193
  return
194
  elif status == "failed":
195
  raise Exception(data.get("error", "Unknown error"))
196
  else:
197
- yield f"⏳ Current status: {status.capitalize()}...", None, None, None
 
198
 
199
  raise Exception("Generation timed out")
200
 
201
  except Exception as e:
202
- error_img = create_error_image(str(e))
203
- yield f"❌ Generation failed: {str(e)}", error_img, "", None
204
 
205
 
 
206
  def cleanup_task():
207
- while True:
208
- session_manager.cleanup_sessions()
209
- time.sleep(3600)
 
210
 
211
 
212
  # Store recent generations
@@ -238,14 +337,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
238
 
239
  with gr.Row():
240
  with gr.Column(scale=1):
241
- prompt = gr.Textbox(label="Prompt",
242
- placeholder="Please enter your prompt...",
243
- lines=3)
244
  image_file = gr.Image(label="Upload Image",
245
  type="filepath",
246
- sources=["upload"],
247
  interactive=True,
248
- image_mode="RGB")
 
 
 
 
 
249
  seed = gr.Number(label="seed",
250
  value=-1,
251
  minimum=-1,
@@ -256,8 +357,8 @@ with gr.Blocks(theme=gr.themes.Soft(),
256
  value=True,
257
  interactive=False)
258
  with gr.Column(scale=1):
259
- status = gr.Textbox(label="Status", elem_classes=["status-box"])
260
  output_image = gr.Image(label="Generated Result")
 
261
  output_url = gr.Textbox(label="Image URL",
262
  interactive=True,
263
  visible=False)
@@ -266,15 +367,15 @@ with gr.Blocks(theme=gr.themes.Soft(),
266
  examples=[
267
  [
268
  "Convert the image into Claymation style.",
269
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
270
  ],
271
  [
272
  "Convert the image into Ghibli style.",
273
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
274
  ],
275
  [
276
- "Add sunglasses to the face of the girl.",
277
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
278
  ],
279
  # [
280
  # 'Convert the image into an ink sketch style.',
@@ -319,11 +420,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
319
  inputs=[image_file, prompt, seed, session_id, enable_safety],
320
  outputs=[status, output_image, output_url, recent_gallery],
321
  api_name=False,
 
 
 
322
  )
323
 
324
  if __name__ == "__main__":
325
- threading.Thread(target=cleanup_task, daemon=True).start()
326
- app.queue(max_size=8).launch(
 
327
  server_name="0.0.0.0",
 
328
  share=False,
329
  )
 
1
  import os
2
  import requests
3
  import time
4
+ import functools
5
  import threading
6
  import uuid
7
  import base64
 
19
  if not API_KEY:
20
  raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
21
 
22
+
23
  MODEL_URL = "TostAI/nsfw-text-detection-large"
24
+ TITLE = "πŸ–ΌοΈπŸ” Image Prompt Safety Classifier πŸ›‘οΈ"
25
+ DESCRIPTION = "✨ Enter an image generation prompt to classify its safety level! ✨"
26
 
27
+ # Load model and tokenizer
28
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
29
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
 
 
30
 
31
+ # Define class names with emojis and detailed descriptions
32
+ CLASS_NAMES = {
33
+ 0: "βœ… SAFE - This prompt is appropriate and harmless.",
34
+ 1: "⚠️ QUESTIONABLE - This prompt may require further review.",
35
+ 2: "🚫 UNSAFE - This prompt is likely to generate inappropriate content."
36
+ }
37
 
38
+
39
@functools.lru_cache(maxsize=128)
def classify_text(text):
    """Run the NSFW safety model over *text*.

    Returns a ``(class_id, description)`` tuple where ``class_id`` indexes
    ``CLASS_NAMES`` (0 = safe, 1 = questionable, 2 = unsafe).  Results are
    memoized, so repeated identical prompts skip the model forward pass.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=1024,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    label_id = torch.argmax(logits, dim=1).item()
    return label_id, CLASS_NAMES[label_id]
54
+
55
+
56
class ClientManager:
    """Registry mapping a client id (here, an IP address) to its
    per-client rate-limit state, guarded by a class-level lock."""

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, client_id=None):
        """Return the ClientGenerationManager for *client_id*, creating it
        on first use.

        A falsy id is replaced with a fresh random UUID, which therefore
        always yields a brand-new manager.
        """
        if not client_id:
            client_id = str(uuid.uuid4())

        with cls._lock:
            manager = cls._instances.get(client_id)
            if manager is None:
                manager = ClientGenerationManager()
                cls._instances[client_id] = manager
            return manager

    @classmethod
    def cleanup_old_clients(cls, max_age=3600):  # 1 hour default
        """Drop managers that have been idle longer than *max_age* seconds."""
        now = time.time()
        with cls._lock:
            stale = [
                cid for cid, mgr in cls._instances.items()
                if hasattr(mgr, "last_activity")
                and now - mgr.last_activity > max_age
            ]
            for cid in stale:
                del cls._instances[cid]
82
 
83
 
84
class ClientGenerationManager:
    """Sliding-window rate-limit bookkeeping for one client (IP).

    All state mutation happens under ``self.lock`` so concurrent requests
    from the same client stay consistent.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.last_activity = time.time()
        self.request_timestamps = []  # epoch seconds of each recorded request

    def update_activity(self):
        """Record that the client was just seen (feeds idle cleanup)."""
        with self.lock:
            self.last_activity = time.time()

    def add_request_timestamp(self):
        """Record one request at the current time."""
        with self.lock:
            self.request_timestamps.append(time.time())

    def has_exceeded_limit(self, limit=20):
        """Return True if *limit* or more requests fell within the last hour.

        Prunes timestamps older than one hour as a side effect.
        """
        with self.lock:
            cutoff = time.time() - 3600
            self.request_timestamps = [
                ts for ts in self.request_timestamps if ts >= cutoff
            ]
            return len(self.request_timestamps) >= limit
108
 
109
 
110
class SessionManager:
    """Registry mapping a Gradio session id to its GenerationManager,
    guarded by a class-level lock."""

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, session_id=None):
        """Return ``(session_id, manager)``, creating both as needed.

        ``None`` is replaced with a fresh random UUID; the (possibly new)
        id is returned so the caller can keep referring to the same state.
        """
        if session_id is None:
            session_id = str(uuid.uuid4())

        with cls._lock:
            manager = cls._instances.get(session_id)
            if manager is None:
                manager = GenerationManager()
                cls._instances[session_id] = manager
            return session_id, manager

    @classmethod
    def cleanup_old_sessions(cls, max_age=3600):  # 1 hour default
        """Drop sessions that have been idle longer than *max_age* seconds."""
        now = time.time()
        with cls._lock:
            stale = [
                sid for sid, mgr in cls._instances.items()
                if hasattr(mgr, "last_activity")
                and now - mgr.last_activity > max_age
            ]
            for sid in stale:
                del cls._instances[sid]
136
+
137
+
138
class GenerationManager:
    """Sliding-window rate-limit bookkeeping for one Gradio session.

    Mirrors ClientGenerationManager but is keyed by session rather than by
    client IP.  Fix: the original had no lock even though the app serves
    concurrent requests (the event handler is registered with
    ``concurrency_limit=20``), so the prune-and-count in
    ``has_exceeded_limit`` could race with ``add_request_timestamp`` and
    lose entries; all state is now guarded by a lock, matching the sibling
    class.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self.last_activity = time.time()
        self.request_timestamps = []  # epoch seconds of each recorded request

    def update_activity(self):
        """Record that the session was just seen (feeds idle cleanup)."""
        with self._lock:
            self.last_activity = time.time()

    def add_request_timestamp(self):
        """Record one request at the current time."""
        with self._lock:
            self.request_timestamps.append(time.time())

    def has_exceeded_limit(self,
                           limit=10):  # Default limit: 10 requests per hour
        """Return True if *limit* or more requests fell within the last hour.

        Prunes timestamps older than one hour as a side effect.
        """
        with self._lock:
            cutoff = time.time() - 3600
            self.request_timestamps = [
                ts for ts in self.request_timestamps if ts >= cutoff
            ]
            return len(self.request_timestamps) >= limit
158
 
159
 
160
  @torch.no_grad()
 
177
  return Image.open(io.BytesIO(image_data))
178
 
179
 
180
+ def generate_image(
181
+ image_file,
182
+ prompt,
183
+ seed,
184
+ session_id,
185
+ enable_safety_checker,
186
+ request: gr.Request,
187
+ ):
188
  try:
189
+ client_ip = request.client.host
190
+ x_forwarded_for = request.headers.get('x-forwarded-for')
191
+ if x_forwarded_for:
192
+ client_ip = x_forwarded_for
193
+ print(f"Client IP: {client_ip}")
194
+ client_generation_manager = ClientManager.get_manager(client_ip)
195
+ client_generation_manager.update_activity()
196
+ if client_generation_manager.has_exceeded_limit(limit=20):
197
+ error_message = "❌ Your network has exceeded the limit of 20 requests per hour. Please try again later."
198
+ yield error_message, None, "", None
199
+ return
200
 
201
+ client_generation_manager.add_request_timestamp()
202
+ """Generate images with big status box during generation"""
203
+ # Get or create a session manager
204
+ session_id, manager = SessionManager.get_manager(session_id)
205
+ manager.update_activity()
206
+
207
+ # Check if the user has exceeded the request limit
208
+ if manager.has_exceeded_limit(
209
+ limit=10): # Set the limit to 10 requests per hour
210
+ error_message = "❌ You have exceeded the limit of 10 requests per hour. Please try again later."
211
+ yield error_message, None, "", None
212
  return
213
 
214
+ # Add the current request timestamp
215
+ manager.add_request_timestamp()
216
+
217
+ if not prompt or prompt.strip() == "":
218
+ # Handle empty prompt case
219
+ error_message = "⚠️ Please enter a prompt first"
220
+ yield error_message, None, "", None
221
+ return
222
 
223
  error_messages = []
224
  if not image_file:
 
228
  if not prompt.strip():
229
  error_messages.append("Prompt cannot be empty")
230
  if error_messages:
231
+ error_message = "❌ Input validation failed: " + ", ".join(
232
+ error_messages)
233
+ yield error_message, None, "", None
234
  return
235
 
236
+ # Check if the prompt is safe
237
+ classification, message = classify_text(prompt)
238
+ if classification == 2: # UNSAFE
239
+ yield "❌ NSFW prompt detected", None, "", None
240
+ return
241
+
242
+ # Status message
243
+ status_message = f"πŸ”„ PROCESSING: '{prompt}'"
244
+ yield status_message, None, "", None
245
+
246
  try:
247
  base64_image = image_to_base64(image_file)
248
  input_image = decode_base64_to_image(base64_image)
249
  except Exception as e:
250
+ error_message = f"❌ File processing failed: {str(e)}"
251
+ yield error_message, None, "", None
252
  return
253
 
254
  headers = {
 
274
  start_time = time.time()
275
 
276
  for _ in range(60):
277
+ time.sleep(1.0)
278
  resp = requests.get(result_url, headers=headers)
279
  resp.raise_for_status()
280
 
 
284
  if status == "completed":
285
  elapsed = time.time() - start_time
286
  output_url = data["outputs"][0]
287
+ yield f"πŸŽ‰ Generation successful! Time taken {elapsed:.1f}s", output_url, output_url, update_recent_gallery(
288
+ prompt, input_image, output_url)
289
  return
290
  elif status == "failed":
291
  raise Exception(data.get("error", "Unknown error"))
292
  else:
293
+ error_message = f"⏳ Current status: {status.capitalize()}..."
294
+ yield error_message, None, "", None
295
 
296
  raise Exception("Generation timed out")
297
 
298
  except Exception as e:
299
+ error_message = f"❌ Generation failed: {str(e)}"
300
+ yield error_message, None, "", None
301
 
302
 
303
# Schedule periodic cleanup of old sessions
def cleanup_task():
    """Purge idle session/client state, then re-arm itself in one hour.

    Fix: ``threading.Timer`` threads are non-daemon by default, so a
    pending timer would keep the interpreter alive for up to an hour at
    shutdown (the previous revision ran cleanup in a ``daemon=True``
    thread).  The timer is now marked daemon before starting.
    """
    SessionManager.cleanup_old_sessions()
    ClientManager.cleanup_old_clients()
    # Schedule the next cleanup; daemonize so exit isn't blocked.
    timer = threading.Timer(3600, cleanup_task)  # Run every hour
    timer.daemon = True
    timer.start()
309
 
310
 
311
  # Store recent generations
 
337
 
338
  with gr.Row():
339
  with gr.Column(scale=1):
 
 
 
340
  image_file = gr.Image(label="Upload Image",
341
  type="filepath",
342
+ sources=["upload", "clipboard"],
343
  interactive=True,
344
+ image_mode="RGB",
345
+ value="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png")
346
+ prompt = gr.Textbox(label="Prompt",
347
+ placeholder="Please enter your prompt...",
348
+ lines=3,
349
+ value="Convert the image into Claymation style.")
350
  seed = gr.Number(label="seed",
351
  value=-1,
352
  minimum=-1,
 
357
  value=True,
358
  interactive=False)
359
  with gr.Column(scale=1):
 
360
  output_image = gr.Image(label="Generated Result")
361
+ status = gr.Textbox(label="Status", elem_classes=["status-box"])
362
  output_url = gr.Textbox(label="Image URL",
363
  interactive=True,
364
  visible=False)
 
367
  examples=[
368
  [
369
  "Convert the image into Claymation style.",
370
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png"
371
  ],
372
  [
373
  "Convert the image into Ghibli style.",
374
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
375
  ],
376
  [
377
+ "Add sunglasses to the face of the statue.",
378
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
379
  ],
380
  # [
381
  # 'Convert the image into an ink sketch style.',
 
420
  inputs=[image_file, prompt, seed, session_id, enable_safety],
421
  outputs=[status, output_image, output_url, recent_gallery],
422
  api_name=False,
423
+ max_batch_size=10,
424
+ concurrency_limit=20,
425
+ concurrency_id="generation",
426
  )
427
 
428
  if __name__ == "__main__":
429
+ # Start the cleanup task
430
+ cleanup_task()
431
+ app.queue(max_size=20).launch(
432
  server_name="0.0.0.0",
433
+ max_threads=10,
434
  share=False,
435
  )