Sof22 commited on
Commit
83cf32a
·
1 Parent(s): bcd4fe1

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +53 -0
main.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query, File, UploadFile
2
+ from fastapi.responses import FileResponse
3
+ import torch
4
+ from diffusion import Diffusion
5
+ from utils import get_id_frame, get_audio_emb, save_video
6
+ import shutil
7
+ from pathlib import Path
8
+
9
+ app = FastAPI()
10
+
11
+ @app.post("/generate_video/")
12
+ async def generate_video(
13
+ id_frame_file: UploadFile = File(...),
14
+ audio_file: UploadFile = File(...),
15
+ gpu: bool = Query(False, description="Use GPU if available"),
16
+ id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
17
+ inference_steps: int = Query(100, description="Number of inference diffusion steps"),
18
+ output: str = Query("output.mp4", description="Path to save the output video")
19
+ ):
20
+ device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
21
+
22
+ print('Loading model...')
23
+ # Load your checkpoint here
24
+ unet = torch.jit.load("your_checkpoint_path_here")
25
+
26
+ # Replace these arguments with the ones from your original args
27
+ diffusion_args = {
28
+ "in_channels": 3,
29
+ "image_size": 128,
30
+ "out_channels": 6,
31
+ "n_timesteps": 1000,
32
+ }
33
+ diffusion = Diffusion(unet, device, **diffusion_args).to(device)
34
+ diffusion.space(inference_steps)
35
+
36
+ # Save uploaded files to disk
37
+ id_frame_path = Path("temp_id_frame.jpg")
38
+ audio_path = Path("temp_audio.mp3")
39
+ with id_frame_path.open("wb") as buffer:
40
+ shutil.copyfileobj(id_frame_file.file, buffer)
41
+
42
+ with audio_path.open("wb") as buffer:
43
+ shutil.copyfileobj(audio_file.file, buffer)
44
+
45
+ id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
46
+ audio, audio_emb = get_audio_emb(str(audio_path), "your_encoder_checkpoint_here", device)
47
+
48
+ samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0))
49
+
50
+ save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
51
+ print(f'Results saved at {output}')
52
+
53
+ return FileResponse(output)