Upload folder using huggingface_hub
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import requests
|
3 |
import time
|
|
|
4 |
import threading
|
5 |
import uuid
|
6 |
import base64
|
@@ -18,78 +19,142 @@ API_KEY = os.getenv("WAVESPEED_API_KEY")
|
|
18 |
if not API_KEY:
|
19 |
raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
|
20 |
|
|
|
21 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
22 |
-
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
except Exception as e:
|
28 |
-
raise RuntimeError(f"Failed to load safety model: {str(e)}")
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
_instances = {}
|
33 |
_lock = threading.Lock()
|
34 |
|
35 |
@classmethod
|
36 |
-
def
|
|
|
|
|
|
|
37 |
with cls._lock:
|
38 |
-
if
|
39 |
-
cls._instances[
|
40 |
-
|
41 |
-
'history': [],
|
42 |
-
'last_active': time.time()
|
43 |
-
}
|
44 |
-
return cls._instances[session_id]
|
45 |
|
46 |
@classmethod
|
47 |
-
def
|
|
|
48 |
with cls._lock:
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
|
58 |
-
class
|
59 |
|
60 |
def __init__(self):
|
61 |
-
self.clients = {}
|
62 |
self.lock = threading.Lock()
|
|
|
|
|
63 |
|
64 |
-
def
|
65 |
with self.lock:
|
66 |
-
|
67 |
-
if client_id not in self.clients:
|
68 |
-
self.clients[client_id] = {'count': 1, 'reset': now + 3600}
|
69 |
-
return True
|
70 |
-
if now > self.clients[client_id]['reset']:
|
71 |
-
self.clients[client_id] = {'count': 1, 'reset': now + 3600}
|
72 |
-
return True
|
73 |
-
if self.clients[client_id]['count'] >= 20:
|
74 |
-
return False
|
75 |
-
self.clients[client_id]['count'] += 1
|
76 |
-
return True
|
77 |
|
|
|
|
|
|
|
78 |
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
|
95 |
@torch.no_grad()
|
@@ -112,28 +177,48 @@ def decode_base64_to_image(base64_str):
|
|
112 |
return Image.open(io.BytesIO(image_data))
|
113 |
|
114 |
|
115 |
-
def generate_image(
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
120 |
try:
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
return
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
error_messages = []
|
139 |
if not image_file:
|
@@ -143,16 +228,27 @@ def generate_image(image_file,
|
|
143 |
if not prompt.strip():
|
144 |
error_messages.append("Prompt cannot be empty")
|
145 |
if error_messages:
|
146 |
-
|
147 |
-
|
|
|
148 |
return
|
149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
try:
|
151 |
base64_image = image_to_base64(image_file)
|
152 |
input_image = decode_base64_to_image(base64_image)
|
153 |
except Exception as e:
|
154 |
-
|
155 |
-
yield
|
156 |
return
|
157 |
|
158 |
headers = {
|
@@ -178,7 +274,7 @@ def generate_image(image_file,
|
|
178 |
start_time = time.time()
|
179 |
|
180 |
for _ in range(60):
|
181 |
-
time.sleep(1)
|
182 |
resp = requests.get(result_url, headers=headers)
|
183 |
resp.raise_for_status()
|
184 |
|
@@ -188,25 +284,28 @@ def generate_image(image_file,
|
|
188 |
if status == "completed":
|
189 |
elapsed = time.time() - start_time
|
190 |
output_url = data["outputs"][0]
|
191 |
-
|
192 |
-
|
193 |
return
|
194 |
elif status == "failed":
|
195 |
raise Exception(data.get("error", "Unknown error"))
|
196 |
else:
|
197 |
-
|
|
|
198 |
|
199 |
raise Exception("Generation timed out")
|
200 |
|
201 |
except Exception as e:
|
202 |
-
|
203 |
-
yield
|
204 |
|
205 |
|
|
|
206 |
def cleanup_task():
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
210 |
|
211 |
|
212 |
# Store recent generations
|
@@ -238,14 +337,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
238 |
|
239 |
with gr.Row():
|
240 |
with gr.Column(scale=1):
|
241 |
-
prompt = gr.Textbox(label="Prompt",
|
242 |
-
placeholder="Please enter your prompt...",
|
243 |
-
lines=3)
|
244 |
image_file = gr.Image(label="Upload Image",
|
245 |
type="filepath",
|
246 |
-
sources=["upload"],
|
247 |
interactive=True,
|
248 |
-
image_mode="RGB"
|
|
|
|
|
|
|
|
|
|
|
249 |
seed = gr.Number(label="seed",
|
250 |
value=-1,
|
251 |
minimum=-1,
|
@@ -256,8 +357,8 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
256 |
value=True,
|
257 |
interactive=False)
|
258 |
with gr.Column(scale=1):
|
259 |
-
status = gr.Textbox(label="Status", elem_classes=["status-box"])
|
260 |
output_image = gr.Image(label="Generated Result")
|
|
|
261 |
output_url = gr.Textbox(label="Image URL",
|
262 |
interactive=True,
|
263 |
visible=False)
|
@@ -266,15 +367,15 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
266 |
examples=[
|
267 |
[
|
268 |
"Convert the image into Claymation style.",
|
269 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
270 |
],
|
271 |
[
|
272 |
"Convert the image into Ghibli style.",
|
273 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
274 |
],
|
275 |
[
|
276 |
-
"Add sunglasses to the face of the
|
277 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
278 |
],
|
279 |
# [
|
280 |
# 'Convert the image into an ink sketch style.',
|
@@ -319,11 +420,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
319 |
inputs=[image_file, prompt, seed, session_id, enable_safety],
|
320 |
outputs=[status, output_image, output_url, recent_gallery],
|
321 |
api_name=False,
|
|
|
|
|
|
|
322 |
)
|
323 |
|
324 |
if __name__ == "__main__":
|
325 |
-
|
326 |
-
|
|
|
327 |
server_name="0.0.0.0",
|
|
|
328 |
share=False,
|
329 |
)
|
|
|
1 |
import os
|
2 |
import requests
|
3 |
import time
|
4 |
+
import functools
|
5 |
import threading
|
6 |
import uuid
|
7 |
import base64
|
|
|
19 |
if not API_KEY:
|
20 |
raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
|
21 |
|
22 |
+
|
23 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
24 |
+
TITLE = "πΌοΈπ Image Prompt Safety Classifier π‘οΈ"
|
25 |
+
DESCRIPTION = "β¨ Enter an image generation prompt to classify its safety level! β¨"
|
26 |
|
27 |
+
# Load model and tokenizer
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
29 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
|
|
|
|
30 |
|
31 |
+
# Define class names with emojis and detailed descriptions
|
32 |
+
CLASS_NAMES = {
|
33 |
+
0: "β
SAFE - This prompt is appropriate and harmless.",
|
34 |
+
1: "β οΈ QUESTIONABLE - This prompt may require further review.",
|
35 |
+
2: "π« UNSAFE - This prompt is likely to generate inappropriate content."
|
36 |
+
}
|
37 |
|
38 |
+
|
39 |
+
@functools.lru_cache(maxsize=128)
def classify_text(text):
    """Classify a prompt's safety level with the NSFW text-detection model.

    Returns a tuple ``(predicted_class, description)`` where the class index
    keys into the module-level ``CLASS_NAMES`` mapping. Results are memoized
    (up to 128 distinct prompts) so repeated prompts skip model inference.
    """
    encoded = tokenizer(text,
                        return_tensors="pt",
                        truncation=True,
                        padding=True,
                        max_length=1024)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        result = model(**encoded)

    label = torch.argmax(result.logits, dim=1).item()
    return label, CLASS_NAMES[label]
|
54 |
+
|
55 |
+
|
56 |
+
class ClientManager:
    """Class-level registry mapping client identifiers to rate-limit managers.

    All access to the shared registry is serialized through ``_lock`` so the
    class is safe to use from concurrent request handlers.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, client_id=None):
        """Return the manager for *client_id*, creating one on first use.

        A random UUID is minted when no (falsy) client id is supplied.
        """
        if not client_id:
            client_id = str(uuid.uuid4())

        with cls._lock:
            if client_id not in cls._instances:
                cls._instances[client_id] = ClientGenerationManager()
            return cls._instances[client_id]

    @classmethod
    def cleanup_old_clients(cls, max_age=3600):  # 1 hour default
        """Drop managers whose last activity is older than *max_age* seconds."""
        now = time.time()
        with cls._lock:
            stale = [
                cid for cid, mgr in cls._instances.items()
                if hasattr(mgr, "last_activity")
                and now - mgr.last_activity > max_age
            ]
            for cid in stale:
                del cls._instances[cid]
|
82 |
|
83 |
|
84 |
+
class ClientGenerationManager:
    """Tracks request timestamps for one client to enforce an hourly rate cap."""

    def __init__(self):
        self.lock = threading.Lock()
        # Time of the most recent call; read by ClientManager cleanup.
        self.last_activity = time.time()
        # Epoch times of recent requests; pruned to the last hour on check.
        self.request_timestamps = []

    def update_activity(self):
        """Record that the client was just active."""
        with self.lock:
            self.last_activity = time.time()

    def add_request_timestamp(self):
        """Register the current moment as one request."""
        with self.lock:
            self.request_timestamps.append(time.time())

    def has_exceeded_limit(self, limit=20):
        """Return True when *limit* or more requests fall within the last hour.

        Also prunes timestamps older than one hour as a side effect.
        """
        with self.lock:
            cutoff = time.time() - 3600
            self.request_timestamps = [
                ts for ts in self.request_timestamps if ts >= cutoff
            ]
            return len(self.request_timestamps) >= limit
|
108 |
|
109 |
|
110 |
+
class SessionManager:
    """Class-level registry mapping session ids to per-session managers.

    The shared registry is guarded by ``_lock`` for concurrent access.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, session_id=None):
        """Return ``(session_id, manager)``; mint a fresh UUID when id is None."""
        if session_id is None:
            session_id = str(uuid.uuid4())

        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = GenerationManager()
            return session_id, cls._instances[session_id]

    @classmethod
    def cleanup_old_sessions(cls, max_age=3600):  # 1 hour default
        """Remove sessions whose last activity is older than *max_age* seconds."""
        now = time.time()
        with cls._lock:
            stale = [
                sid for sid, mgr in cls._instances.items()
                if hasattr(mgr, "last_activity")
                and now - mgr.last_activity > max_age
            ]
            for sid in stale:
                del cls._instances[sid]
|
136 |
+
|
137 |
+
|
138 |
+
class GenerationManager:
    """Per-session request bookkeeping for the hourly rate limit.

    NOTE(review): unlike ClientGenerationManager this class takes no lock;
    presumably each instance is only touched by its own session -- confirm.
    """

    def __init__(self):
        # Time of the most recent call; read by SessionManager cleanup.
        self.last_activity = time.time()
        # Epoch times of recent requests; pruned to the last hour on check.
        self.request_timestamps = []

    def update_activity(self):
        """Record that the session was just active."""
        self.last_activity = time.time()

    def add_request_timestamp(self):
        """Register the current moment as one request."""
        self.request_timestamps.append(time.time())

    def has_exceeded_limit(self,
                           limit=10):  # Default limit: 10 requests per hour
        """Return True when *limit* or more requests fall within the last hour.

        Also prunes timestamps older than one hour as a side effect.
        """
        cutoff = time.time() - 3600
        self.request_timestamps = [
            ts for ts in self.request_timestamps if ts >= cutoff
        ]
        return len(self.request_timestamps) >= limit
|
158 |
|
159 |
|
160 |
@torch.no_grad()
|
|
|
177 |
return Image.open(io.BytesIO(image_data))
|
178 |
|
179 |
|
180 |
+
def generate_image(
|
181 |
+
image_file,
|
182 |
+
prompt,
|
183 |
+
seed,
|
184 |
+
session_id,
|
185 |
+
enable_safety_checker,
|
186 |
+
request: gr.Request,
|
187 |
+
):
|
188 |
try:
|
189 |
+
client_ip = request.client.host
|
190 |
+
x_forwarded_for = request.headers.get('x-forwarded-for')
|
191 |
+
if x_forwarded_for:
|
192 |
+
client_ip = x_forwarded_for
|
193 |
+
print(f"Client IP: {client_ip}")
|
194 |
+
client_generation_manager = ClientManager.get_manager(client_ip)
|
195 |
+
client_generation_manager.update_activity()
|
196 |
+
if client_generation_manager.has_exceeded_limit(limit=20):
|
197 |
+
error_message = "β Your network has exceeded the limit of 20 requests per hour. Please try again later."
|
198 |
+
yield error_message, None, "", None
|
199 |
+
return
|
200 |
|
201 |
+
client_generation_manager.add_request_timestamp()
|
202 |
+
"""Generate images with big status box during generation"""
|
203 |
+
# Get or create a session manager
|
204 |
+
session_id, manager = SessionManager.get_manager(session_id)
|
205 |
+
manager.update_activity()
|
206 |
+
|
207 |
+
# Check if the user has exceeded the request limit
|
208 |
+
if manager.has_exceeded_limit(
|
209 |
+
limit=10): # Set the limit to 10 requests per hour
|
210 |
+
error_message = "β You have exceeded the limit of 10 requests per hour. Please try again later."
|
211 |
+
yield error_message, None, "", None
|
212 |
return
|
213 |
|
214 |
+
# Add the current request timestamp
|
215 |
+
manager.add_request_timestamp()
|
216 |
+
|
217 |
+
if not prompt or prompt.strip() == "":
|
218 |
+
# Handle empty prompt case
|
219 |
+
error_message = "β οΈ Please enter a prompt first"
|
220 |
+
yield error_message, None, "", None
|
221 |
+
return
|
222 |
|
223 |
error_messages = []
|
224 |
if not image_file:
|
|
|
228 |
if not prompt.strip():
|
229 |
error_messages.append("Prompt cannot be empty")
|
230 |
if error_messages:
|
231 |
+
error_message = "β Input validation failed: " + ", ".join(
|
232 |
+
error_messages)
|
233 |
+
yield error_message, None, "", None
|
234 |
return
|
235 |
|
236 |
+
# Check if the prompt is safe
|
237 |
+
classification, message = classify_text(prompt)
|
238 |
+
if classification == 2: # UNSAFE
|
239 |
+
yield "β NSFW prompt detected", None, "", None
|
240 |
+
return
|
241 |
+
|
242 |
+
# Status message
|
243 |
+
status_message = f"π PROCESSING: '{prompt}'"
|
244 |
+
yield status_message, None, "", None
|
245 |
+
|
246 |
try:
|
247 |
base64_image = image_to_base64(image_file)
|
248 |
input_image = decode_base64_to_image(base64_image)
|
249 |
except Exception as e:
|
250 |
+
error_message = f"β File processing failed: {str(e)}"
|
251 |
+
yield error_message, None, "", None
|
252 |
return
|
253 |
|
254 |
headers = {
|
|
|
274 |
start_time = time.time()
|
275 |
|
276 |
for _ in range(60):
|
277 |
+
time.sleep(1.0)
|
278 |
resp = requests.get(result_url, headers=headers)
|
279 |
resp.raise_for_status()
|
280 |
|
|
|
284 |
if status == "completed":
|
285 |
elapsed = time.time() - start_time
|
286 |
output_url = data["outputs"][0]
|
287 |
+
yield f"π Generation successful! Time taken {elapsed:.1f}s", output_url, output_url, update_recent_gallery(
|
288 |
+
prompt, input_image, output_url)
|
289 |
return
|
290 |
elif status == "failed":
|
291 |
raise Exception(data.get("error", "Unknown error"))
|
292 |
else:
|
293 |
+
error_message = f"β³ Current status: {status.capitalize()}..."
|
294 |
+
yield error_message, None, "", None
|
295 |
|
296 |
raise Exception("Generation timed out")
|
297 |
|
298 |
except Exception as e:
|
299 |
+
error_message = f"β Generation failed: {str(e)}"
|
300 |
+
yield error_message, None, "", None
|
301 |
|
302 |
|
303 |
+
# Schedule periodic cleanup of old sessions
def cleanup_task():
    """Purge stale sessions/clients, then reschedule itself to run hourly.

    Bug fix: the rescheduling ``threading.Timer`` must be a daemon thread;
    a non-daemon timer that perpetually re-arms itself keeps the interpreter
    alive and prevents the process from ever exiting cleanly.
    """
    SessionManager.cleanup_old_sessions()
    ClientManager.cleanup_old_clients()
    # Schedule the next cleanup (daemon so it never blocks shutdown).
    timer = threading.Timer(3600, cleanup_task)  # Run every hour
    timer.daemon = True
    timer.start()
|
309 |
|
310 |
|
311 |
# Store recent generations
|
|
|
337 |
|
338 |
with gr.Row():
|
339 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
340 |
image_file = gr.Image(label="Upload Image",
|
341 |
type="filepath",
|
342 |
+
sources=["upload", "clipboard"],
|
343 |
interactive=True,
|
344 |
+
image_mode="RGB",
|
345 |
+
value="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png")
|
346 |
+
prompt = gr.Textbox(label="Prompt",
|
347 |
+
placeholder="Please enter your prompt...",
|
348 |
+
lines=3,
|
349 |
+
value="Convert the image into Claymation style.")
|
350 |
seed = gr.Number(label="seed",
|
351 |
value=-1,
|
352 |
minimum=-1,
|
|
|
357 |
value=True,
|
358 |
interactive=False)
|
359 |
with gr.Column(scale=1):
|
|
|
360 |
output_image = gr.Image(label="Generated Result")
|
361 |
+
status = gr.Textbox(label="Status", elem_classes=["status-box"])
|
362 |
output_url = gr.Textbox(label="Image URL",
|
363 |
interactive=True,
|
364 |
visible=False)
|
|
|
367 |
examples=[
|
368 |
[
|
369 |
"Convert the image into Claymation style.",
|
370 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png"
|
371 |
],
|
372 |
[
|
373 |
"Convert the image into Ghibli style.",
|
374 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
375 |
],
|
376 |
[
|
377 |
+
"Add sunglasses to the face of the statue.",
|
378 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
|
379 |
],
|
380 |
# [
|
381 |
# 'Convert the image into an ink sketch style.',
|
|
|
420 |
inputs=[image_file, prompt, seed, session_id, enable_safety],
|
421 |
outputs=[status, output_image, output_url, recent_gallery],
|
422 |
api_name=False,
|
423 |
+
max_batch_size=10,
|
424 |
+
concurrency_limit=20,
|
425 |
+
concurrency_id="generation",
|
426 |
)
|
427 |
|
428 |
if __name__ == "__main__":
|
429 |
+
# Start the cleanup task
|
430 |
+
cleanup_task()
|
431 |
+
app.queue(max_size=20).launch(
|
432 |
server_name="0.0.0.0",
|
433 |
+
max_threads=10,
|
434 |
share=False,
|
435 |
)
|