# musev-demo/scripts/runpod_handler.py
# Repository page header: "jmanhype's picture" - "Initial Space setup" - commit 0a72c84
import runpod
import os
import sys
from pathlib import Path
import torch
import gradio as gr
import tempfile
from PIL import Image
import numpy as np
import yaml
from typing import Dict, Any, Optional
import threading
# Add the MuseV directory to the Python path
musev_path = str(Path(__file__).parent.parent)
sys.path.append(musev_path)
# Import MuseV modules (adjust these imports based on the actual module structure)
from musev.pipelines import MuseVPipeline
class MuseVService:
    """Own a MuseV pipeline and turn a condition image plus prompt into a video.

    The pipeline is loaded eagerly in ``__init__`` and placed on CUDA when it
    is available, otherwise on the CPU.
    """

    # Longest side (in pixels) allowed when the output size is derived from
    # the condition image rather than supplied by the caller.
    MAX_SIDE = 1024

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the pretrained MuseV pipeline onto ``self.device``."""
        # float16 is only requested on GPU; CPU kernels have limited half
        # precision support, so fall back to float32 there.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Initialize the MuseV pipeline (adjust parameters as needed)
        self.pipeline = MuseVPipeline.from_pretrained(
            "TMElyralab/MuseV",
            torch_dtype=dtype,
            device=self.device,
        )
        self.pipeline.to(self.device)

    @staticmethod
    def _resolve_dimensions(image_size, height, width, max_side=1024):
        """Return the ``(height, width)`` to render at.

        Args:
            image_size: ``(width, height)`` of the condition image.
            height: Requested height, or ``None``/non-positive for "unset".
            width: Requested width, or ``None``/non-positive for "unset".
            max_side: Cap on the longest side when both values are unset.

        Returns:
            A ``(height, width)`` tuple of ints. When both values are unset
            the image size is used, capped at ``max_side`` on the longest
            side. When exactly one is set, the other is derived from the
            image's aspect ratio. When both are set they are used as given.
        """

        def _clean(value):
            # Normalise "unset" (None or non-positive) and coerce floats.
            if value is None:
                return None
            value = int(value)
            return value if value > 0 else None

        height = _clean(height)
        width = _clean(width)
        src_width, src_height = image_size
        aspect_ratio = src_width / src_height

        if height is None and width is None:
            if src_width > src_height:
                width = min(src_width, max_side)
                height = int(width / aspect_ratio)
            else:
                height = min(src_height, max_side)
                width = int(height * aspect_ratio)
        elif height is None:
            height = int(width / aspect_ratio)
        elif width is None:
            width = int(height * aspect_ratio)

        # Extreme aspect ratios could otherwise truncate a side to zero.
        return max(1, height), max(1, width)

    def generate_video(
        self,
        condition_image: "Image.Image",
        prompt: str,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: int = 16,
        eye_blinks_factor: float = 1.8,
    ) -> bytes:
        """Generate a video from ``condition_image`` and return its bytes.

        Args:
            condition_image: Image used as the condition, IP-adapter and
                reference image.
            prompt: Text prompt passed to the pipeline.
            height: Output height; derived from the image when omitted.
            width: Output width; derived from the image when omitted.
            num_frames: Accepted for interface compatibility.
            eye_blinks_factor: Forwarded to the pipeline configuration.

        Returns:
            The raw bytes of the generated ``output.mp4`` file.
        """
        # NOTE(review): num_frames is not forwarded to the pipeline config;
        # confirm the config key the pipeline expects before wiring it in.
        height, width = self._resolve_dimensions(
            condition_image.size, height, width, self.MAX_SIDE
        )

        # Everything is written to a throwaway directory; the video is read
        # back into memory before the directory is removed.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save condition image
            condition_image_path = os.path.join(temp_dir, "condition.png")
            condition_image.save(condition_image_path)

            # Prepare configuration
            config = {
                "condition_images": condition_image_path,
                "prompt": prompt,
                "height": height,
                "width": width,
                "eye_blinks_factor": eye_blinks_factor,
                "img_length_ratio": 1.0,
                "ipadapter_image": condition_image_path,
                "refer_image": condition_image_path,
            }

            # Generate video
            output_path = os.path.join(temp_dir, "output.mp4")
            self.pipeline.generate(config, output_path)

            # Read the video file and return as bytes
            with open(output_path, "rb") as f:
                video_bytes = f.read()

        return video_bytes
# Initialize the service once at import time: MuseVService.__init__ loads the
# model, and both the RunPod handler and the Gradio callback below use this
# single shared instance.
service = MuseVService()
def handler(event):
    """
    RunPod handler function for API requests.

    Expects ``event["input"]`` to be a dict with a base64-encoded ``image``
    (a ``data:`` URL prefix is tolerated) and optional ``prompt``,
    ``height``, ``width`` and ``eye_blinks_factor`` entries.

    Returns a dict with ``status`` set to ``"success"`` and the base64 video
    under ``output.video``, or ``status`` set to ``"error"`` and the message
    under ``error``. Exceptions are never propagated to the caller.
    """
    # Imported here so the handler does not rely on file-level imports.
    import base64
    from io import BytesIO

    try:
        # Get the input data; reject a missing or malformed payload with a
        # readable message instead of a bare KeyError/AttributeError text.
        job_input = event.get("input") if isinstance(event, dict) else None
        if not isinstance(job_input, dict):
            raise ValueError("No input provided")

        # Process the input image
        image_data = job_input.get("image")
        if not image_data:
            raise ValueError("No image provided")

        # Tolerate "data:image/png;base64,...." by keeping only the payload.
        if isinstance(image_data, str) and image_data.startswith("data:"):
            image_data = image_data.partition(",")[2]

        # Convert base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(image_data)))

        # An explicit null in the request should behave like an absent key.
        prompt = job_input.get("prompt")
        eye_blinks_factor = job_input.get("eye_blinks_factor")

        # Generate video
        video_bytes = service.generate_video(
            condition_image=image,
            prompt="" if prompt is None else prompt,
            height=job_input.get("height"),
            width=job_input.get("width"),
            eye_blinks_factor=1.8 if eye_blinks_factor is None else eye_blinks_factor,
        )

        # Encode video as base64
        video_base64 = base64.b64encode(video_bytes).decode()
        return {
            "status": "success",
            "output": {
                "video": video_base64
            }
        }
    except Exception as e:
        # Top-level boundary: report the failure in the response payload.
        return {
            "status": "error",
            "error": str(e)
        }
def create_gradio_interface():
    """
    Create Gradio interface.

    Returns a ``gr.Interface`` whose callback runs ``service.generate_video``
    on the uploaded image and returns the path of an ``.mp4`` file holding
    the result.
    """
    def generate_video_gradio(image, prompt, height, width, eye_blinks_factor):
        """Run generation for one UI request and return the video file path."""
        try:
            if image is None:
                raise ValueError("No image provided")

            video_bytes = service.generate_video(
                condition_image=Image.fromarray(image),
                prompt=prompt,
                # A cleared Number field arrives as None and 0 means "unset";
                # Number values are floats, so coerce the rest to int.
                height=int(height) if height and height > 0 else None,
                width=int(width) if width and width > 0 else None,
                eye_blinks_factor=eye_blinks_factor
            )

            # Save video to temporary file. delete=False because the path is
            # handed back to Gradio after this function returns.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_file.write(video_bytes)
            return temp_file.name
        except Exception as e:
            raise gr.Error(str(e)) from e

    example_rows = [
        [
            "path/to/example/image.jpg",
            "(masterpiece, best quality, highres:1),(1person, solo:1),(eye blinks:1.8),(head wave:1.3)",
            512,
            512,
            1.8
        ]
    ]
    # Only offer examples whose image file is actually present; the default
    # path above is a placeholder. NOTE(review): Gradio is expected to reject
    # example files that do not exist - confirm for the pinned version.
    example_rows = [row for row in example_rows if os.path.exists(row[0])]

    # Create the interface
    interface = gr.Interface(
        fn=generate_video_gradio,
        inputs=[
            gr.Image(label="Input Image", type="numpy"),
            gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
            gr.Number(label="Height (optional)", value=0),
            gr.Number(label="Width (optional)", value=0),
            gr.Slider(minimum=0.0, maximum=3.0, value=1.8, label="Eye Blinks Factor")
        ],
        outputs=gr.Video(label="Generated Video"),
        title="MuseV Video Generation",
        description="Generate videos from images using MuseV",
        examples=example_rows or None
    )
    return interface
if __name__ == "__main__":
    # Serve the Gradio UI and the RunPod worker from the same process.
    interface = create_gradio_interface()

    # Gradio runs on a daemon thread; the RunPod worker is started on the
    # main thread afterwards.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 3000,
        "share": False
    }
    gradio_thread = threading.Thread(
        target=interface.launch,
        kwargs=launch_options,
        daemon=True
    )
    gradio_thread.start()

    # Start the RunPod handler
    runpod.serverless.start({"handler": handler})