jiandan1998 commited on
Commit
b44f555
·
verified ·
1 Parent(s): 1c95961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -40
app.py CHANGED
@@ -4,11 +4,8 @@ import json
4
  import time
5
  import threading
6
  import uuid
7
- import shutil
8
  import base64
9
- from datetime import datetime
10
  from pathlib import Path
11
- from http.server import HTTPServer, SimpleHTTPRequestHandler
12
  from dotenv import load_dotenv
13
  import gradio as gr
14
  import random
@@ -84,7 +81,6 @@ def image_to_base64(file_path):
84
  if len(img_data) == 0:
85
  raise ValueError("空文件")
86
 
87
- # 使用URL安全编码并自动填充
88
  encoded = base64.urlsafe_b64encode(img_data)
89
  missing_padding = len(encoded) % 4
90
  if missing_padding:
@@ -118,7 +114,7 @@ def classify_prompt(prompt):
118
  return torch.argmax(outputs.logits).item()
119
 
120
  def generate_video(
121
- image,
122
  prompt,
123
  duration,
124
  enable_safety,
@@ -131,6 +127,11 @@ def generate_video(
131
  session_id
132
  ):
133
 
 
 
 
 
 
134
  safety_level = classify_prompt(prompt)
135
  if safety_level != 0:
136
  error_img = create_error_image(CLASS_NAMES[safety_level])
@@ -150,24 +151,26 @@ def generate_video(
150
  api_key = os.getenv("WAVESPEED_API_KEY")
151
  if not api_key:
152
  raise ValueError("API key missing")
153
-
154
- base64_img = image_to_base64(image)
 
155
  headers = {
156
  "Authorization": f"Bearer {api_key}",
157
  "Content-Type": "application/json"
158
- }
159
 
160
  payload = {
161
- "context_scale": 1,
162
- "enable_safety_checker": True,
 
 
163
  "flow_shift": flow_shift,
 
164
  "guidance_scale": guidance,
165
- "images": [base64_img],
166
  "negative_prompt": negative_prompt,
167
  "num_inference_steps": steps,
168
- "prompt": prompt,
169
- "seed": seed if seed != -1 else random.randint(0, 999999),
170
- "size": "480*832"
171
  }
172
 
173
  response = requests.post(
@@ -236,44 +239,56 @@ with gr.Blocks(
236
 
237
  session_id = gr.State(str(uuid.uuid4()))
238
 
239
- gr.Markdown("# 🌊 Wan-2.1-i2v-480p-Ultra-Fast Run On WaveSpeedAI")
 
 
 
240
  gr.Markdown("""
241
- [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
242
- Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
243
  """)
244
 
245
  with gr.Row():
246
  with gr.Column(scale=1):
247
- img_input = gr.Image(type="filepath", label="Upload Image")
248
- prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Prompt...")
249
- negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
 
 
 
 
250
 
251
  with gr.Row():
252
- size = gr.Dropdown(["832*480", "480*832"], value="832 * 480", interactive=True, label="Resolution")
253
- steps = gr.Slider(1, 50, value=30, label="Inference Steps")
 
 
 
 
254
  with gr.Row():
255
- duration = gr.Slider(1, 10, value=5, step=1, label="时长(秒)")
256
- guidance = gr.Slider(1, 20, value=7, label="Guidance Scale")
257
  with gr.Row():
258
- seed = gr.Number(-1, label="Seed")
259
- random_seed_btn = gr.Button("Random🎲Seed", variant="secondary")
260
  with gr.Row():
261
- enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",value=True, interactive=False)
262
- flow_shift = gr.Slider(1, 50, value=16, label="flow_shift")
263
 
264
  with gr.Column(scale=1):
265
- video_output = gr.Video(label="Generated Video", format="mp4", elem_classes=["video-preview"])
266
- status_output = gr.Textbox(label="System Status", interactive=False, lines=4)
267
- generate_btn = gr.Button("Generate Video", variant="primary")
268
 
269
  gr.Examples(
270
- examples=[
271
- ["The elegant lady carefully selects bags in the boutique, and she shows the charm of a mature woman in a black slim dress with a pearl necklace. Holding a vintage-inspired blue leather half-moon handbag, she is carefully observing its craftsmanship and texture. The interior of the store is a haven of sophistication and luxury. Soft, ambient lighting casts a warm glow over the polished wooden floors",
272
- "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png"
 
 
273
  ]
274
- ],
275
  inputs=[prompt, img_input],
276
- label="Example Inputs",
277
  examples_per_page=3
278
  )
279
 
@@ -297,10 +312,7 @@ with gr.Blocks(
297
  size,
298
  session_id
299
  ],
300
- outputs=[
301
- status_output,
302
- video_output
303
- ]
304
  )
305
 
306
  if __name__ == "__main__":
 
4
  import time
5
  import threading
6
  import uuid
 
7
  import base64
 
8
  from pathlib import Path
 
9
  from dotenv import load_dotenv
10
  import gradio as gr
11
  import random
 
81
  if len(img_data) == 0:
82
  raise ValueError("空文件")
83
 
 
84
  encoded = base64.urlsafe_b64encode(img_data)
85
  missing_padding = len(encoded) % 4
86
  if missing_padding:
 
114
  return torch.argmax(outputs.logits).item()
115
 
116
  def generate_video(
117
+ image_files,
118
  prompt,
119
  duration,
120
  enable_safety,
 
127
  session_id
128
  ):
129
 
130
+ if len(image_files) != 2:
131
+ error_img = create_error_image("upload 2 images")
132
+ yield "❌ error: upload 2 images", error_img
133
+ return
134
+
135
  safety_level = classify_prompt(prompt)
136
  if safety_level != 0:
137
  error_img = create_error_image(CLASS_NAMES[safety_level])
 
151
  api_key = os.getenv("WAVESPEED_API_KEY")
152
  if not api_key:
153
  raise ValueError("API key missing")
154
+
155
+ base64_images = [image_to_base64(img) for img in image_files]
156
+
157
  headers = {
158
  "Authorization": f"Bearer {api_key}",
159
  "Content-Type": "application/json"
160
+ }
161
 
162
  payload = {
163
+ "seed": seed if seed != -1 else random.randint(0, 999999),
164
+ "size": size.replace(" ", ""),
165
+ "images": base64_images,
166
+ "prompt": prompt,
167
  "flow_shift": flow_shift,
168
+ "context_scale": 1,
169
  "guidance_scale": guidance,
 
170
  "negative_prompt": negative_prompt,
171
  "num_inference_steps": steps,
172
+ "enable_safety_checker": enable_safety,
173
+ "model_id": "wavespeed-ai/wan-2.1-14b-vace"
 
174
  }
175
 
176
  response = requests.post(
 
239
 
240
  session_id = gr.State(str(uuid.uuid4()))
241
 
242
+ gr.Markdown("# 🌊 Wan-2.1-14B-VACE")
243
+ gr.Markdown("
244
+ VACE is an all-in-one model designed for video creation and editing. It encompasses various tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), allowing users to compose these tasks freely. This functionality enables users to explore diverse possibilities and streamlines their workflows effectively, offering a range of capabilities, such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more."
245
+ )
246
  gr.Markdown("""
247
+ [WaveSpeedAI](https://wavespeed.ai/) 提供先进的AI视频生成加速技术
 
248
  """)
249
 
250
  with gr.Row():
251
  with gr.Column(scale=1):
252
+ img_input = gr.File(
253
+ file_count="multiple",
254
+ file_types=["image"],
255
+ label="upload 2 images"
256
+ )
257
+ prompt = gr.Textbox(label="prompt", lines=3, placeholder="请输入描述...")
258
+ negative_prompt = gr.Textbox(label="negative_prompt", lines=2)
259
 
260
  with gr.Row():
261
+ size = gr.Dropdown(
262
+ ["480*832", "832*480"],
263
+ value="480*832",
264
+ label="resolution"
265
+ )
266
+ steps = gr.Slider(1, 50, value=30, label="推理步数")
267
  with gr.Row():
268
+ duration = gr.Slider(1, 10, value=5, step=1, label="视频时长(秒)")
269
+ guidance = gr.Slider(1, 20, value=7, label="引导系数")
270
  with gr.Row():
271
+ seed = gr.Number(-1, label="随机种子")
272
+ random_seed_btn = gr.Button("随机种子🎲", variant="secondary")
273
  with gr.Row():
274
+ enable_safety = gr.Checkbox(label="🔒 安全检测", value=True)
275
+ flow_shift = gr.Slider(1, 50, value=16, label="运动幅度")
276
 
277
  with gr.Column(scale=1):
278
+ video_output = gr.Video(label="生成结果", format="mp4")
279
+ status_output = gr.Textbox(label="系统状态", interactive=False, lines=4)
280
+ generate_btn = gr.Button("开始生成", variant="primary")
281
 
282
  gr.Examples(
283
+ examples=[[
284
+ "The elegant lady carefully selects bags in the boutique...",
285
+ [
286
+ "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
287
+ "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
288
  ]
289
+ ]],
290
  inputs=[prompt, img_input],
291
+ label="示例输入",
292
  examples_per_page=3
293
  )
294
 
 
312
  size,
313
  session_id
314
  ],
315
+ outputs=[status_output, video_output]
 
 
 
316
  )
317
 
318
  if __name__ == "__main__":