Sergidev commited on
Commit
feae090
·
1 Parent(s): b75ec24
Files changed (1) hide show
  1. demo_app.py +19 -8
demo_app.py CHANGED
@@ -101,14 +101,25 @@ def generate(
101
  generator = torch.Generator('cuda').manual_seed(seed_value)
102
 
103
  with torch.amp.autocast_mode.autocast('cuda', dtype=torch.bfloat16), torch.inference_mode(), torch.no_grad():
104
- output = pipe(
105
- prompt=prompt,
106
- height=height,
107
- width=width,
108
- num_frames=num_frames,
109
- num_inference_steps=num_inference_steps,
110
- generator=generator,
111
- ).frames[0]
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  output_path = "output.mp4"
114
  export_to_video(output, output_path, fps=fps)
 
101
  generator = torch.Generator('cuda').manual_seed(seed_value)
102
 
103
  with torch.amp.autocast_mode.autocast('cuda', dtype=torch.bfloat16), torch.inference_mode(), torch.no_grad():
104
+ # Use image input if provided, else use text prompt
105
+ if image_input:
106
+ output = pipe(
107
+ image=Image.open(image_input).convert("RGB"),
108
+ height=height,
109
+ width=width,
110
+ num_frames=num_frames,
111
+ num_inference_steps=num_inference_steps,
112
+ generator=generator,
113
+ ).frames[0]
114
+ else:
115
+ output = pipe(
116
+ prompt=prompt,
117
+ height=height,
118
+ width=width,
119
+ num_frames=num_frames,
120
+ num_inference_steps=num_inference_steps,
121
+ generator=generator,
122
+ ).frames[0]
123
 
124
  output_path = "output.mp4"
125
  export_to_video(output, output_path, fps=fps)