# diffused-heads/main.py
import shutil
from pathlib import Path

import torch
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.responses import FileResponse

from diffusion import Diffusion
from utils import get_id_frame, get_audio_emb, save_video

app = FastAPI()
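
# Note: the handler below loads the model on every request, which is slow.
# A common alternative (a sketch, not part of the original file) is to load
# the checkpoint once at startup and keep it on app.state:
#
#     @app.on_event("startup")
#     def load_model() -> None:
#         app.state.unet = torch.jit.load("your_checkpoint_path_here")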

@app.post("/generate_video/")
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(False, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, description="Number of inference diffusion steps"),
    output: str = Query("output.mp4", description="Path to save the output video"),
):
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'

    print('Loading model...')
    # Load your checkpoint here (a TorchScript UNet; the placeholder path
    # must be replaced with a real checkpoint).
    unet = torch.jit.load("your_checkpoint_path_here")

    # Replace these arguments with the ones from your original args.
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    # Reduce the full schedule to `inference_steps` sampling steps.
    diffusion.space(inference_steps)

    # Save the uploaded files to disk so the loaders below can read them by path.
    id_frame_path = Path("temp_id_frame.jpg")
    audio_path = Path("temp_audio.mp3")
    with id_frame_path.open("wb") as buffer:
        shutil.copyfileobj(id_frame_file.file, buffer)
    with audio_path.open("wb") as buffer:
        shutil.copyfileobj(audio_file.file, buffer)

    id_frame = get_id_frame(
        str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]
    ).to(device)
    audio, audio_emb = get_audio_emb(str(audio_path), "your_encoder_checkpoint_here", device)

    samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0))
    save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
    print(f'Results saved at {output}')

    return FileResponse(output)
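
# A minimal way to serve this app directly (an assumption; the original file
# does not include an entry point). Requires `uvicorn` to be installed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request, assuming the server runs on localhost:8000 and that
# face.jpg / speech.mp3 exist locally (hypothetical file names):
#
#   curl -X POST "http://localhost:8000/generate_video/?inference_steps=100" \
#        -F "id_frame_file=@face.jpg" \
#        -F "audio_file=@speech.mp3" \
#        -o output.mp4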