jiandan1998 committed on
Commit
9a748ec
·
verified ·
1 Parent(s): 0f5b61b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import time
5
+ import random
6
+ import base64
7
+ import uuid
8
+ import threading
9
+ from pathlib import Path
10
+ from dotenv import load_dotenv
11
+ import gradio as gr
12
+ import torch
13
+ import logging
14
+ from PIL import Image, ImageDraw, ImageFont
15
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
16
+
17
# Load variables from a .env file (python-dotenv); generate_video later reads
# WAVESPEED_API_KEY from the environment.
load_dotenv()

# Hugging Face identifier of the prompt-safety classifier.
MODEL_URL = "TostAI/nsfw-text-detection-large"
# Display label for each class index returned by classify_prompt().
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}
# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
23
+
24
class SessionManager:
    """Process-wide registry of per-session bookkeeping records.

    Every record is a plain dict with the keys ``'count'`` (int),
    ``'history'`` (list) and ``'last_active'`` (``time.time()`` timestamp).
    All access to the registry is serialized through a class-level lock.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the record for *session_id*, creating it on first use.

        The record's ``'last_active'`` timestamp is refreshed on every call
        while the lock is held, so a session that is being used cannot be
        evicted by a concurrent ``cleanup_sessions`` between the lookup and
        the caller's own update.
        """
        with cls._lock:
            now = time.time()
            session = cls._instances.get(session_id)
            if session is None:
                session = cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': now,
                }
            else:
                session['last_active'] = now
            return session

    @classmethod
    def cleanup_sessions(cls, max_idle_seconds=3600):
        """Drop sessions idle for more than *max_idle_seconds*.

        Args:
            max_idle_seconds: Idle time after which a session is evicted.
                Defaults to one hour, the previously hard-coded value.

        Returns:
            The number of sessions that were removed.
        """
        with cls._lock:
            now = time.time()
            # Collect first: a dict must not be mutated while iterating it.
            expired = [
                key
                for key, record in cls._instances.items()
                if now - record['last_active'] > max_idle_seconds
            ]
            for key in expired:
                del cls._instances[key]
        return len(expired)
46
+
47
class RateLimiter:
    """Fixed-window request counter keyed by client id.

    Each client may make ``max_requests`` calls per ``window_seconds``; the
    window starts at the client's first request and restarts on the first
    request after it has elapsed.
    """

    def __init__(self, max_requests=20, window_seconds=3600):
        """Create a limiter.

        Args:
            max_requests: Allowed requests per window (default 20).
            window_seconds: Window length in seconds (default one hour).
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client_id -> {'count': requests used, 'reset': window end timestamp}
        self.clients = {}
        self.lock = threading.Lock()
        self._next_purge = time.time() + window_seconds

    def _purge_expired(self, now):
        """Forget clients whose window has ended (caller holds the lock)."""
        stale = [cid for cid, entry in self.clients.items() if now > entry['reset']]
        for cid in stale:
            del self.clients[cid]
        self._next_purge = now + self.window_seconds

    def check(self, client_id):
        """Record one request for *client_id*.

        Returns:
            ``True`` if the request is within the limit (and has been
            counted), ``False`` if the client has exhausted its window.
        """
        with self.lock:
            now = time.time()
            # Without periodic purging the table would keep one entry per
            # client id ever seen; sweep at most once per window.
            if now >= self._next_purge:
                self._purge_expired(now)

            entry = self.clients.get(client_id)
            if entry is None or now > entry['reset']:
                entry = self.clients[client_id] = {
                    'count': 0,
                    'reset': now + self.window_seconds,
                }
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True
65
+
66
# Module-level instances used by generate_video() and cleanup_task().
session_manager = SessionManager()
rate_limiter = RateLimiter()
68
+
69
def create_error_image(message):
    """Render *message* onto a placeholder image and return its file path.

    Args:
        message: Text to display. Any value is accepted and converted with
            ``str()``; text longer than 60 characters is truncated and
            prefixed with ``"Error: "``.

    Returns:
        The path of the written JPEG (``"error.jpg"`` in the working
        directory).
    """
    # Callers forward API error fields, which are not guaranteed to be str.
    message = str(message)

    img = Image.new("RGB", (832, 480), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Font file missing or FreeType support unavailable.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(img)
    text = f"Error: {message[:60]}..." if len(message) > 60 else message
    try:
        draw.text((50, 200), text, fill="#ff0000", font=font)
    except UnicodeEncodeError:
        # Pillow's bitmap fallback font cannot encode non-Latin text; degrade
        # to a readable placeholder rather than failing the error path itself.
        ascii_text = text.encode("ascii", "replace").decode("ascii")
        draw.text((50, 200), ascii_text, fill="#ff0000", font=font)

    # NOTE(review): the fixed filename is shared by all requests.
    img.save("error.jpg")
    return "error.jpg"
80
+
81
def classify_prompt(prompt):
    """Return the index of the highest-scoring class for *prompt*.

    The text is tokenized (truncated to 512 tokens) and run through the
    module-level classifier without gradient tracking. The result indexes
    into ``CLASS_NAMES``.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    best_class = torch.argmax(logits)
    return best_class.item()
86
+
87
def image_to_base64(file_path):
    """Read *file_path* and return its contents as a base64 string.

    Args:
        file_path: Path of the file to encode.

    Returns:
        The standard, padded base64 encoding of the file's bytes as ``str``.

    Raises:
        ValueError: If the file cannot be opened or read; the underlying
            error is attached as ``__cause__``.
    """
    try:
        with open(file_path, "rb") as image_file:
            raw_data = image_file.read()
    except (OSError, TypeError, ValueError) as e:
        raise ValueError(f"Base64编码失败: {str(e)}") from e
    # b64encode always emits correctly padded output, so no manual padding
    # is required.
    return base64.b64encode(raw_data).decode("ascii")
98
+
99
def generate_video(
    context_scale,
    enable_safety_checker,
    flow_shift,
    guidance_scale,
    images,
    negative_prompt,
    num_inference_steps,
    prompt,
    seed,
    size,
    task,
    video,
    session_id,
):
    """Submit a generation job to the WaveSpeed API and stream its progress.

    This is a generator. Every yielded item is a ``(status_text, media)``
    pair: ``media`` is the path of an error image on failure, ``None`` while
    the job is in progress, and the result URL on completion.

    Steps, in order: classify the prompt and stop unless it is class 0,
    apply the per-session rate limit, base64-encode the first two uploaded
    images, POST the job, then poll the result endpoint until the job
    completes, fails, or the overall deadline passes.

    Args:
        images: Uploaded image file paths; at least two are required and
            only the first two are sent.
        seed: ``-1`` or ``None`` selects a random seed; other values are
            sent as ``int``.
        video: Value of the input video component (a single value or a
            list/tuple whose first element is used).
        session_id: Key for rate limiting and session bookkeeping.
        Remaining arguments are forwarded unchanged in the request payload.
    """
    request_timeout = 30      # seconds allowed for each HTTP call
    max_wait_seconds = 1800   # overall deadline for the polling loop
    poll_interval = 0.5       # seconds between polls

    safety_level = classify_prompt(prompt)
    if safety_level != 0:
        error_img = create_error_image(CLASS_NAMES[safety_level])
        yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
        return

    if not rate_limiter.check(session_id):
        error_img = create_error_image("每小时限制20次请求")
        yield "❌ 请求过于频繁,请稍后再试", error_img
        return

    session = session_manager.get_session(session_id)
    session['last_active'] = time.time()
    session['count'] += 1

    api_key = os.getenv("WAVESPEED_API_KEY")
    if not api_key:
        error_img = create_error_image("API密钥缺失")
        yield "❌ Error: Missing API Key", error_img
        return

    try:
        if not images or len(images) < 2:
            raise ValueError("需要上传至少两张图片")
        base64_images = [image_to_base64(img_path) for img_path in images[:2]]
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ 文件处理失败: {str(e)}", error_img
        return

    # NOTE(review): the video value is forwarded as-is (stringified), while
    # images are sent base64-encoded — confirm what the API expects here.
    video_payload = ""
    if isinstance(video, (list, tuple)):
        video_payload = video[0] if video else ""
    elif video is not None:
        video_payload = video

    # The seed widget may deliver None (cleared field) or a float.
    if seed is None or seed == -1:
        resolved_seed = random.randint(0, 999999)
    else:
        resolved_seed = int(seed)

    payload = {
        "context_scale": context_scale,
        "enable_fast_mode": False,
        "enable_safety_checker": enable_safety_checker,
        "flow_shift": flow_shift,
        "guidance_scale": guidance_scale,
        "images": base64_images,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "prompt": prompt,
        "seed": resolved_seed,
        "size": size,
        "task": task,
        "video": str(video_payload) if video_payload else "",
    }

    # Log a summary of the images instead of megabytes of base64 text.
    loggable = dict(payload)
    loggable["images"] = [f"<base64, {len(b)} chars>" for b in base64_images]
    logging.debug("API请求payload: %s", json.dumps(loggable, indent=2))

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    try:
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/wan-2.1-14b-vace",
            headers=headers,
            data=json.dumps(payload),
            timeout=request_timeout,
        )

        if response.status_code != 200:
            error_img = create_error_image(response.text)
            yield f"❌ API错误 ({response.status_code}): {response.text}", error_img
            return

        request_id = response.json()["data"]["id"]
        yield f"✅ 任务已提交 (ID: {request_id})", None
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ 连接错误: {str(e)}", error_img
        return

    result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
    start_time = time.time()
    deadline = start_time + max_wait_seconds

    while True:
        # Without a deadline a job stuck in a non-terminal state would hold
        # this worker forever.
        if time.time() > deadline:
            timeout_msg = f"任务超时: 超过 {max_wait_seconds} 秒仍未完成"
            error_img = create_error_image(timeout_msg)
            yield f"❌ {timeout_msg}", error_img
            return

        time.sleep(poll_interval)
        try:
            response = requests.get(
                result_url, headers=headers, timeout=request_timeout
            )
            if response.status_code != 200:
                error_img = create_error_image(response.text)
                yield f"❌ 轮询错误 ({response.status_code}): {response.text}", error_img
                return

            data = response.json()["data"]
            status = data["status"]

            if status == "completed":
                elapsed = time.time() - start_time
                video_url = data['outputs'][0]
                session["history"].append(video_url)
                yield (f"🎉 完成! 耗时 {elapsed:.1f}秒\n"
                       f"下载链接: {video_url}"), video_url
                return

            if status == "failed":
                # The error field may be missing, empty, or not a string.
                failure = str(data.get('error') or '未知错误')
                error_img = create_error_image(failure)
                yield f"❌ 任务失败: {failure}", error_img
                return

            yield f"⏳ 状态: {status.capitalize()}...", None

        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ 轮询失败: {str(e)}", error_img
            return
235
+
236
def cleanup_task(interval=3600):
    """Evict idle sessions forever, sleeping *interval* seconds between runs.

    Intended to run in a background thread; it never returns.

    Args:
        interval: Seconds to sleep between cleanup passes (default one hour).
    """
    while True:
        try:
            session_manager.cleanup_sessions()
        except Exception:
            # An uncaught exception would end the thread and silently stop
            # all future cleanup; log it and keep going.
            logging.exception("Session cleanup failed")
        time.sleep(interval)
240
+
241
# UI definition: inputs on the left, output video and status on the right.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .video-preview { max-width: 600px !important; }
    .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
    .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
    .warning { background: #fff3e0; border: 1px solid #ffcc80; }
    .error { background: #ffebee; border: 1px solid #ef9a9a; }
    """
) as app:

    # Pass a callable so a fresh id is generated on each page load. A plain
    # string would be evaluated once at build time and shared by every
    # visitor, collapsing all users into one rate-limit bucket.
    session_id = gr.State(lambda: str(uuid.uuid4()))

    gr.Markdown("# 🌊Wan-2.1-14B-Vace Run On [WaveSpeedAI](https://wavespeed.ai/)")
    gr.Markdown("""VACE is an all-in-one model designed for video creation and editing. It encompasses various tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), allowing users to compose these tasks freely. This functionality enables users to explore diverse possibilities and streamlines their workflows effectively, offering a range of capabilities, such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more.""")

    with gr.Row():
        with gr.Column(scale=1):
            images = gr.File(label="upload image", file_count="multiple", file_types=["image"], type="filepath", elem_id="image-uploader")
            video = gr.Video(label="Input Video", format="mp4", sources=["upload"])
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt...")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
            size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
            context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
            num_inference_steps = gr.Slider(1, 100, value=20, step=1, label="Inference Steps")
            task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
            seed = gr.Number(-1, label="Seed")
            random_seed_btn = gr.Button("Random🎲Seed", variant="secondary")
            guidance = gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance_Scale")
            flow_shift = gr.Slider(1, 20, value=16, step=1, label="Shift")
            enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
        with gr.Column(scale=1):
            video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
            generate_btn = gr.Button("Generate", variant="primary")
            status_output = gr.Textbox(label="status", interactive=False, lines=4)
            # NOTE(review): placement of the examples inside this column is
            # assumed — confirm against the intended layout.
            gr.Examples(
                examples=[
                    [
                        "The elegant lady carefully selects bags in the boutique, and she shows the charm of a mature woman in a black slim dress with a pearl necklace, as well as her pretty face. Holding a vintage-inspired blue leather half-moon handbag, she is carefully observing its craftsmanship and texture. The interior of the store is a haven of sophistication and luxury. Soft, ambient lighting casts a warm glow over the polished wooden floors",
                        [
                            "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
                            "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
                        ]
                    ]
                ],
                inputs=[prompt, images],
            )

    # Fill the seed field with a fresh random value.
    random_seed_btn.click(
        fn=lambda: random.randint(0, 999999),
        outputs=seed
    )

    # Input order must match generate_video's positional parameters.
    generate_btn.click(
        generate_video,
        inputs=[
            context_scale,
            enable_safety_checker,
            flow_shift,
            guidance,
            images,
            negative_prompt,
            num_inference_steps,
            prompt,
            seed,
            size,
            task,
            video,
            session_id,
        ],
        outputs=[status_output, video_output]
    )
313
+
314
# Root logger: DEBUG and above, written to both a log file and the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler("gradio_app.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Gradio's own logger is kept at INFO.
gradio_logger = logging.getLogger("gradio")
gradio_logger.setLevel(logging.INFO)
325
+
326
if __name__ == "__main__":
    # Daemon thread: it must not keep the process alive on shutdown.
    janitor = threading.Thread(target=cleanup_task, daemon=True)
    janitor.start()

    # Queue at most 4 pending requests, then serve on all interfaces.
    queued_app = app.queue(max_size=4)
    queued_app.launch(server_name="0.0.0.0", max_threads=16, share=False)