jiandan1998 commited on
Commit
0f5b61b
·
verified ·
1 Parent(s): e73a514

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -324
app.py DELETED
@@ -1,324 +0,0 @@
1
- import os
2
- import requests
3
- import json
4
- import time
5
- import threading
6
- import uuid
7
- import base64
8
- from pathlib import Path
9
- from dotenv import load_dotenv
10
- import gradio as gr
11
- import random
12
- import torch
13
- from PIL import Image, ImageDraw, ImageFont
14
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
- from functools import lru_cache
16
-
17
- load_dotenv()
18
-
19
- MODEL_URL = "TostAI/nsfw-text-detection-large"
20
- CLASS_NAMES = {
21
- 0: "✅ SAFE",
22
- 1: "⚠️ QUESTIONABLE",
23
- 2: "🚫 UNSAFE"
24
- }
25
-
26
- tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
27
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
28
-
29
- class SessionManager:
30
- _instances = {}
31
- _lock = threading.Lock()
32
-
33
- @classmethod
34
- def get_session(cls, session_id):
35
- with cls._lock:
36
- if session_id not in cls._instances:
37
- cls._instances[session_id] = {
38
- 'count': 0,
39
- 'history': [],
40
- 'last_active': time.time()
41
- }
42
- return cls._instances[session_id]
43
-
44
- @classmethod
45
- def cleanup_sessions(cls):
46
- with cls._lock:
47
- now = time.time()
48
- expired = [k for k, v in cls._instances.items() if now - v['last_active'] > 3600]
49
- for k in expired:
50
- del cls._instances[k]
51
-
52
- class RateLimiter:
53
- def __init__(self):
54
- self.clients = {}
55
- self.lock = threading.Lock()
56
-
57
- def check(self, client_id):
58
- with self.lock:
59
- now = time.time()
60
- if client_id not in self.clients:
61
- self.clients[client_id] = {'count': 1, 'reset': now + 3600}
62
- return True
63
-
64
- if now > self.clients[client_id]['reset']:
65
- self.clients[client_id] = {'count': 1, 'reset': now + 3600}
66
- return True
67
-
68
- if self.clients[client_id]['count'] >= 20:
69
- return False
70
-
71
- self.clients[client_id]['count'] += 1
72
- return True
73
-
74
- session_manager = SessionManager()
75
- rate_limiter = RateLimiter()
76
-
77
- def image_to_base64(file_path):
78
- try:
79
- with open(file_path, "rb") as f:
80
- img_data = f.read()
81
- if len(img_data) == 0:
82
- raise ValueError("空文件")
83
-
84
- encoded = base64.urlsafe_b64encode(img_data)
85
- missing_padding = len(encoded) % 4
86
- if missing_padding:
87
- encoded += b'=' * (4 - missing_padding)
88
-
89
- ext = Path(file_path).suffix.lower()[1:]
90
- mime_map = {'jpg':'jpeg','jpeg':'jpeg','png':'png','webp':'webp','gif':'gif'}
91
- mime = mime_map.get(ext, 'jpeg')
92
- return f"data:image/{mime};base64,{encoded.decode()}"
93
- except Exception as e:
94
- raise ValueError(f"Base64 Error: {str(e)}")
95
-
96
- def create_error_image(message):
97
- img = Image.new("RGB", (832, 480), "#ffdddd")
98
- try:
99
- font = ImageFont.truetype("arial.ttf", 24)
100
- except:
101
- font = ImageFont.load_default()
102
-
103
- draw = ImageDraw.Draw(img)
104
- text = f"Error: {message[:60]}..." if len(message) > 60 else message
105
- draw.text((50, 200), text, fill="#ff0000", font=font)
106
- img.save("error.jpg")
107
- return "error.jpg"
108
-
109
- @lru_cache(maxsize=100)
110
- def classify_prompt(prompt):
111
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
112
- with torch.no_grad():
113
- outputs = model(**inputs)
114
- return torch.argmax(outputs.logits).item()
115
-
116
- def generate_video(
117
- image_files,
118
- prompt,
119
- duration,
120
- enable_safety,
121
- flow_shift,
122
- guidance,
123
- negative_prompt,
124
- steps,
125
- seed,
126
- size,
127
- session_id
128
- ):
129
-
130
- if len(image_files) != 2:
131
- error_img = create_error_image("upload 2 images")
132
- yield "❌ error: upload 2 images", error_img
133
- return
134
-
135
- safety_level = classify_prompt(prompt)
136
- if safety_level != 0:
137
- error_img = create_error_image(CLASS_NAMES[safety_level])
138
- yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
139
- return
140
-
141
- if not rate_limiter.check(session_id):
142
- error_img = create_error_image("Hourly limit exceeded (20 requests)")
143
- yield "❌ 请求过于频繁,请稍后再试", error_img
144
- return
145
-
146
- session = session_manager.get_session(session_id)
147
- session['last_active'] = time.time()
148
- session['count'] += 1
149
-
150
- try:
151
- api_key = os.getenv("WAVESPEED_API_KEY")
152
- if not api_key:
153
- raise ValueError("API key missing")
154
-
155
- base64_images = [image_to_base64(img) for img in image_files]
156
-
157
- headers = {
158
- "Authorization": f"Bearer {api_key}",
159
- "Content-Type": "application/json"
160
- }
161
-
162
- payload = {
163
- "seed": seed if seed != -1 else random.randint(0, 999999),
164
- "size": size.replace(" ", ""),
165
- "images": base64_images,
166
- "prompt": prompt,
167
- "flow_shift": flow_shift,
168
- "context_scale": 1,
169
- "guidance_scale": guidance,
170
- "negative_prompt": negative_prompt,
171
- "num_inference_steps": steps,
172
- "enable_safety_checker": enable_safety,
173
- "model_id": "wavespeed-ai/wan-2.1-14b-vace"
174
- }
175
-
176
- response = requests.post(
177
- "https://api.wavespeed.ai/api/v3/wavespeed-ai/wan-2.1-14b-vace",
178
- headers=headers,
179
- json=payload
180
- )
181
-
182
- if response.status_code != 200:
183
- raise Exception(f"API Error {response.status_code}: {response.text}")
184
-
185
- requestId = response.json()["data"]["id"]
186
- yield f"✅ 任务已提交 (ID: {requestId})", None
187
-
188
- except Exception as e:
189
- error_img = create_error_image(str(e))
190
- yield f"❌ 提交失败: {str(e)}", error_img
191
- return
192
-
193
- result_url = f"https://api.wavespeed.ai/api/v3/predictions/{requestId}/result"
194
- start_time = time.time()
195
-
196
- while True:
197
- time.sleep(1)
198
- try:
199
- resp = requests.get(result_url, headers=headers)
200
- if resp.status_code != 200:
201
- raise Exception(f"状态查询失败: {resp.text}")
202
-
203
- data = resp.json()["data"]
204
- status = data["status"]
205
-
206
- if status == "completed":
207
- elapsed = time.time() - start_time
208
- video_url = data["outputs"][0]
209
- session["history"].append(video_url)
210
- yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", video_url
211
- return
212
-
213
- elif status == "failed":
214
- raise Exception(data.get("error", "Unknown error"))
215
-
216
- else:
217
- yield f"⏳ 当前状态: {status.capitalize()}...", None
218
-
219
- except Exception as e:
220
- error_img = create_error_image(str(e))
221
- yield f"❌ 生成失败: {str(e)}", error_img
222
- return
223
-
224
- def cleanup_task():
225
- while True:
226
- session_manager.cleanup_sessions()
227
- time.sleep(3600)
228
-
229
- with gr.Blocks(
230
- theme=gr.themes.Soft(),
231
- css="""
232
- .video-preview { max-width: 600px !important; }
233
- .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
234
- .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
235
- .warning { background: #fff3e0; border: 1px solid #ffcc80; }
236
- .error { background: #ffebee; border: 1px solid #ef9a9a; }
237
- """
238
- ) as app:
239
-
240
- session_id = gr.State(str(uuid.uuid4()))
241
-
242
- gr.Markdown("# 🌊 Wan-2.1-14B-VACE")
243
- gr.Markdown("""
244
- VACE is an all-in-one model designed for video creation and editing. It encompasses various tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), allowing users to compose these tasks freely. This functionality enables users to explore diverse possibilities and streamlines their workflows effectively, offering a range of capabilities, such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more.
245
- """)
246
- gr.Markdown("""
247
- [WaveSpeedAI](https://wavespeed.ai/) 提供先进的AI视频生成加速技术
248
- """)
249
-
250
- with gr.Row():
251
- with gr.Column(scale=1):
252
- img_input = gr.File(
253
- file_count="multiple",
254
- file_types=["image"],
255
- label="upload 2 images"
256
- )
257
- prompt = gr.Textbox(label="prompt", lines=3, placeholder="请输入描述...")
258
- negative_prompt = gr.Textbox(label="negative_prompt", lines=2)
259
-
260
- with gr.Row():
261
- size = gr.Dropdown(
262
- ["480*832", "832*480"],
263
- value="480*832",
264
- label="resolution"
265
- )
266
- steps = gr.Slider(1, 50, value=30, label="推理步数")
267
- with gr.Row():
268
- duration = gr.Slider(1, 10, value=5, step=1, label="视频时长(秒)")
269
- guidance = gr.Slider(1, 20, value=7, label="引导系数")
270
- with gr.Row():
271
- seed = gr.Number(-1, label="随机种子")
272
- random_seed_btn = gr.Button("随机种子🎲", variant="secondary")
273
- with gr.Row():
274
- enable_safety = gr.Checkbox(label="🔒 安全检测", value=True)
275
- flow_shift = gr.Slider(1, 50, value=16, label="运动幅度")
276
-
277
- with gr.Column(scale=1):
278
- video_output = gr.Video(label="生成结果", format="mp4")
279
- status_output = gr.Textbox(label="系统状态", interactive=False, lines=4)
280
- generate_btn = gr.Button("开始生成", variant="primary")
281
-
282
- gr.Examples(
283
- examples=[[
284
- "The elegant lady carefully selects bags in the boutique...",
285
- [
286
- "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
287
- "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
288
- ]
289
- ]],
290
- inputs=[prompt, img_input],
291
- label="示例输入",
292
- examples_per_page=3
293
- )
294
-
295
- random_seed_btn.click(
296
- fn=lambda: random.randint(0, 999999),
297
- outputs=seed
298
- )
299
-
300
- generate_btn.click(
301
- generate_video,
302
- inputs=[
303
- img_input,
304
- prompt,
305
- duration,
306
- enable_safety,
307
- flow_shift,
308
- guidance,
309
- negative_prompt,
310
- steps,
311
- seed,
312
- size,
313
- session_id
314
- ],
315
- outputs=[status_output, video_output]
316
- )
317
-
318
- if __name__ == "__main__":
319
- threading.Thread(target=cleanup_task, daemon=True).start()
320
- app.queue(max_size=4).launch(
321
- server_name="0.0.0.0",
322
- max_threads=16,
323
- share=False
324
- )