Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
from fastapi import FastAPI, Query, File, UploadFile
|
2 |
from fastapi.responses import FileResponse
|
3 |
import torch
|
4 |
-
from diffusion import Diffusion
|
5 |
-
from utils import get_id_frame, get_audio_emb, save_video
|
6 |
import shutil
|
7 |
from pathlib import Path
|
8 |
|
@@ -12,18 +12,15 @@ app = FastAPI()
|
|
12 |
async def generate_video(
|
13 |
id_frame_file: UploadFile = File(...),
|
14 |
audio_file: UploadFile = File(...),
|
15 |
-
gpu: bool = Query(
|
16 |
id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
|
17 |
inference_steps: int = Query(100, description="Number of inference diffusion steps"),
|
18 |
-
output: str = Query("
|
19 |
):
|
20 |
device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
|
21 |
|
22 |
print('Loading model...')
|
23 |
-
|
24 |
-
unet = torch.jit.load("your_checkpoint_path_here")
|
25 |
-
|
26 |
-
# Replace these arguments with the ones from your original args
|
27 |
diffusion_args = {
|
28 |
"in_channels": 3,
|
29 |
"image_size": 128,
|
@@ -43,9 +40,15 @@ async def generate_video(
|
|
43 |
shutil.copyfileobj(audio_file.file, buffer)
|
44 |
|
45 |
id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
|
46 |
-
audio, audio_emb = get_audio_emb(str(audio_path), "
|
47 |
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
|
51 |
print(f'Results saved at {output}')
|
|
|
1 |
from fastapi import FastAPI, Query, File, UploadFile
|
2 |
from fastapi.responses import FileResponse
|
3 |
import torch
|
4 |
+
from diffusion import Diffusion # Make sure you import your own modules correctly
|
5 |
+
from utils import get_id_frame, get_audio_emb, save_video # Make sure you import your own modules correctly
|
6 |
import shutil
|
7 |
from pathlib import Path
|
8 |
|
|
|
12 |
async def generate_video(
|
13 |
id_frame_file: UploadFile = File(...),
|
14 |
audio_file: UploadFile = File(...),
|
15 |
+
gpu: bool = Query(True, description="Use GPU if available"),
|
16 |
id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
|
17 |
inference_steps: int = Query(100, description="Number of inference diffusion steps"),
|
18 |
+
output: str = Query("/Users/a/Documents/Automations/git talking heads/output_video.mp4", description="Path to save the output video")
|
19 |
):
|
20 |
device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
|
21 |
|
22 |
print('Loading model...')
|
23 |
+
unet = torch.jit.load("/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt")
|
|
|
|
|
|
|
24 |
diffusion_args = {
|
25 |
"in_channels": 3,
|
26 |
"image_size": 128,
|
|
|
40 |
shutil.copyfileobj(audio_file.file, buffer)
|
41 |
|
42 |
id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
|
43 |
+
audio, audio_emb = get_audio_emb(str(audio_path), "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt", device)
|
44 |
|
45 |
+
unet_args = {
|
46 |
+
"n_audio_motion_embs": 2,
|
47 |
+
"n_motion_frames": 2,
|
48 |
+
"motion_channels": 3
|
49 |
+
}
|
50 |
+
|
51 |
+
samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)
|
52 |
|
53 |
save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
|
54 |
print(f'Results saved at {output}')
|