Sof22 committed on
Commit d44dd3f · 1 Parent(s): f4fcda6

Update main.py

Files changed (1)
  1. main.py +13 -10
main.py CHANGED
@@ -1,8 +1,8 @@
  from fastapi import FastAPI, Query, File, UploadFile
  from fastapi.responses import FileResponse
  import torch
- from diffusion import Diffusion
- from utils import get_id_frame, get_audio_emb, save_video
+ from diffusion import Diffusion  # Make sure you import your own modules correctly
+ from utils import get_id_frame, get_audio_emb, save_video  # Make sure you import your own modules correctly
  import shutil
  from pathlib import Path

@@ -12,18 +12,15 @@ app = FastAPI()
  async def generate_video(
      id_frame_file: UploadFile = File(...),
      audio_file: UploadFile = File(...),
-     gpu: bool = Query(False, description="Use GPU if available"),
+     gpu: bool = Query(True, description="Use GPU if available"),
      id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
      inference_steps: int = Query(100, description="Number of inference diffusion steps"),
-     output: str = Query("output.mp4", description="Path to save the output video")
+     output: str = Query("/Users/a/Documents/Automations/git talking heads/output_video.mp4", description="Path to save the output video")
  ):
      device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'

      print('Loading model...')
-     # Load your checkpoint here
-     unet = torch.jit.load("your_checkpoint_path_here")
-
-     # Replace these arguments with the ones from your original args
+     unet = torch.jit.load("/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt")
      diffusion_args = {
          "in_channels": 3,
          "image_size": 128,
@@ -43,9 +40,15 @@ async def generate_video(
      shutil.copyfileobj(audio_file.file, buffer)

      id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
-     audio, audio_emb = get_audio_emb(str(audio_path), "your_encoder_checkpoint_here", device)
+     audio, audio_emb = get_audio_emb(str(audio_path), "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt", device)

-     samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0))
+     unet_args = {
+         "n_audio_motion_embs": 2,
+         "n_motion_frames": 2,
+         "motion_channels": 3
+     }
+
+     samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)

      save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
      print(f'Results saved at {output}')
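
For reference, the updated endpoint takes the identity frame and the audio as multipart uploads, with the remaining options passed as query parameters. Below is a minimal client sketch; the route path (/generate_video), host and port, and the local file names are assumptions, since neither the route decorator nor a server command appears in this diff.

# Hypothetical client sketch for the generate_video endpoint changed in this commit.
# Assumptions: the app is served at http://localhost:8000 and the route is
# POST /generate_video (the decorator is outside the diff, so this is a guess).
import requests

files = {
    "id_frame_file": open("id_frame.mp4", "rb"),   # identity frame source (image or video)
    "audio_file": open("speech.wav", "rb"),        # driving audio
}
params = {
    "gpu": True,                   # new default in this commit
    "id_frame_random": False,
    "inference_steps": 100,
    "output": "output_video.mp4",  # overrides the hard-coded absolute default path
}

resp = requests.post("http://localhost:8000/generate_video", files=files, params=params)
print(resp.status_code, resp.headers.get("content-type"))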