from fastapi import FastAPI, Query, File, UploadFile
from fastapi.responses import FileResponse
import torch
import shutil
import uuid
from pathlib import Path

# Project-local modules: diffusion.py and utils.py must be importable
# from the working directory.
from diffusion import Diffusion
from utils import get_id_frame, get_audio_emb, save_video

# Checkpoint locations (adjust these paths to your environment).
UNET_CHECKPOINT = "/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt"
AUDIO_ENCODER_CHECKPOINT = "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt"

app = FastAPI()

# The endpoint is a plain `def` (not `async def`): sampling is long-running,
# blocking work, and FastAPI runs sync endpoints in a thread pool instead of
# blocking the event loop.
@app.post("/generate_video/")
def generate_video(
        id_frame_file: UploadFile = File(...),
        audio_file: UploadFile = File(...),
        gpu: bool = Query(True, description="Use GPU if available"),
        id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
        inference_steps: int = Query(100, description="Number of inference diffusion steps"),
        output: str = Query("/Users/a/Documents/Automations/git talking heads/output_video.mp4", description="Path to save the output video")
):
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'

    print('Loading model...')
    # NOTE: the model is reloaded on every request; for production, load it
    # once at startup. map_location keeps CUDA-saved weights loadable on CPU.
    unet = torch.jit.load(UNET_CHECKPOINT, map_location=device)
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    diffusion.space(inference_steps)

    # Persist the uploads to disk under unique names so concurrent requests
    # don't overwrite each other's temporary files.
    request_id = uuid.uuid4().hex
    id_frame_path = Path(f"temp_id_frame_{request_id}.jpg")
    audio_path = Path(f"temp_audio_{request_id}.mp3")
    with id_frame_path.open("wb") as buffer:
        shutil.copyfileobj(id_frame_file.file, buffer)

    with audio_path.open("wb") as buffer:
        shutil.copyfileobj(audio_file.file, buffer)

    id_frame = get_id_frame(str(id_frame_path), random=id_frame_random,
                            resize=diffusion_args["image_size"]).to(device)
    audio, audio_emb = get_audio_emb(str(audio_path), AUDIO_ENCODER_CHECKPOINT, device)

    unet_args = {
        "n_audio_motion_embs": 2,
        "n_motion_frames": 2,
        "motion_channels": 3
    }
    
    samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)

    save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
    print(f'Results saved at {output}')

    # Remove the temporary uploads before returning the rendered video.
    id_frame_path.unlink(missing_ok=True)
    audio_path.unlink(missing_ok=True)

    return FileResponse(output, media_type="video/mp4")
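
# A minimal way to run and exercise the service locally. This is a sketch:
# it assumes `uvicorn` and `requests` are installed (`pip install uvicorn
# requests`); neither is used by the original script, and the file names in
# the client example below are placeholders.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)

# Example client call (run from a separate process while the server is up):
#
#   import requests
#   with open("id_frame.jpg", "rb") as f_img, open("speech.mp3", "rb") as f_aud:
#       resp = requests.post(
#           "http://127.0.0.1:8000/generate_video/",
#           files={"id_frame_file": f_img, "audio_file": f_aud},
#           params={"inference_steps": 100, "gpu": "false"},
#       )
#   with open("result.mp4", "wb") as f_out:
#       f_out.write(resp.content)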