from fastapi import FastAPI, Query, File, UploadFile
from fastapi.responses import FileResponse
import torch
from diffusion import Diffusion
from utils import get_id_frame, get_audio_emb, save_video
import shutil
from pathlib import Path

app = FastAPI()


@app.post("/generate_video/")
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(False, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, description="Number of inference diffusion steps"),
    output: str = Query("output.mp4", description="Path to save the output video"),
):
    # Fall back to CPU when a GPU is requested but not available
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'

    print('Loading model...')
    # Load your checkpoint here
    unet = torch.jit.load("your_checkpoint_path_here")

    # Replace these arguments with the ones from your original args
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }

    # Build the diffusion wrapper and shorten the sampling schedule
    # to the requested number of inference steps
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    diffusion.space(inference_steps)

    # Save uploaded files to disk
    id_frame_path = Path("temp_id_frame.jpg")
    audio_path = Path("temp_audio.mp3")
    with id_frame_path.open("wb") as buffer:
        shutil.copyfileobj(id_frame_file.file, buffer)
    with audio_path.open("wb") as buffer:
        shutil.copyfileobj(audio_file.file, buffer)

    # Prepare the identity frame and the audio embedding that conditions the sampler
    id_frame = get_id_frame(
        str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]
    ).to(device)
    audio, audio_emb = get_audio_emb(str(audio_path), "your_encoder_checkpoint_here", device)

    # Sample the video frames and mux them with the uploaded audio
    samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0))
    save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
    print(f'Results saved at {output}')

    return FileResponse(output)
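

# A minimal sketch of how this app could be served and called. It assumes
# uvicorn is installed; the port, module name, and file names below are
# illustrative placeholders, not values from the original script.
if __name__ == "__main__":
    import uvicorn

    # Serve the endpoint locally (host/port chosen for illustration)
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request from a client, e.g. with curl (file names are illustrative):
#
#   curl -X POST "http://localhost:8000/generate_video/?gpu=true&inference_steps=50&output=result.mp4" \
#        -F "id_frame_file=@id_frame.jpg" \
#        -F "audio_file=@speech.mp3" \
#        --output result.mp4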