Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import sys
|
|
|
|
|
4 |
import shutil
|
5 |
import uuid
|
6 |
import subprocess
|
@@ -15,6 +17,13 @@ snapshot_download(
|
|
15 |
local_dir = "./checkpoints"
|
16 |
)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
import tempfile
|
19 |
from moviepy.editor import VideoFileClip
|
20 |
from pydub import AudioSegment
|
@@ -78,7 +87,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|
78 |
from accelerate.utils import set_seed
|
79 |
from latentsync.whisper.audio2feature import Audio2Feature
|
80 |
|
81 |
-
|
82 |
def main(video_path, audio_path, progress=gr.Progress(track_tqdm=True)):
|
83 |
inference_ckpt_path = "checkpoints/latentsync_unet.pt"
|
84 |
unet_config_path = "configs/unet/second_stage.yaml"
|
@@ -118,7 +127,7 @@ def main(video_path, audio_path, progress=gr.Progress(track_tqdm=True)):
|
|
118 |
unet, _ = UNet3DConditionModel.from_pretrained(
|
119 |
OmegaConf.to_container(config.model),
|
120 |
inference_ckpt_path, # load checkpoint
|
121 |
-
device=
|
122 |
)
|
123 |
|
124 |
unet = unet.to(dtype=torch.float16)
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import sys
|
4 |
+
import torch
|
5 |
+
import spaces
|
6 |
import shutil
|
7 |
import uuid
|
8 |
import subprocess
|
|
|
17 |
local_dir = "./checkpoints"
|
18 |
)
|
19 |
|
20 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
+
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
torch_dtype = torch.float16
|
24 |
+
else:
|
25 |
+
torch_dtype = torch.float32
|
26 |
+
|
27 |
import tempfile
|
28 |
from moviepy.editor import VideoFileClip
|
29 |
from pydub import AudioSegment
|
|
|
87 |
from accelerate.utils import set_seed
|
88 |
from latentsync.whisper.audio2feature import Audio2Feature
|
89 |
|
90 |
+
@spaces.GPU
|
91 |
def main(video_path, audio_path, progress=gr.Progress(track_tqdm=True)):
|
92 |
inference_ckpt_path = "checkpoints/latentsync_unet.pt"
|
93 |
unet_config_path = "configs/unet/second_stage.yaml"
|
|
|
127 |
unet, _ = UNet3DConditionModel.from_pretrained(
|
128 |
OmegaConf.to_container(config.model),
|
129 |
inference_ckpt_path, # load checkpoint
|
130 |
+
device=device,
|
131 |
)
|
132 |
|
133 |
unet = unet.to(dtype=torch.float16)
|