"""RunPod serverless worker and Gradio UI for MuseV image-to-video generation."""

import base64
import logging
import os
import sys
import tempfile
import threading
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import numpy as np
import runpod
import torch
import yaml
from PIL import Image

logger = logging.getLogger(__name__)

# Add the MuseV directory to the Python path
musev_path = str(Path(__file__).parent.parent)
sys.path.append(musev_path)

# Import MuseV modules (adjust these imports based on the actual module structure)
from musev.pipelines import MuseVPipeline  # noqa: E402  (needs the sys.path entry above)

# Longest side used when the output size is derived entirely from the input image.
MAX_AUTO_SIDE = 1024

# Placeholder example image for the Gradio UI; it is only offered if it exists.
EXAMPLE_IMAGE_PATH = "path/to/example/image.jpg"


def _optional_positive_int(value: Any) -> Optional[int]:
    """Return ``value`` as an ``int``, or ``None`` when it is unset or non-positive.

    Shared by both front ends: the Gradio number boxes default to ``0`` and
    yield floats, while API payloads may omit the field entirely.
    """
    if value is None:
        return None
    number = int(value)
    return number if number > 0 else None


def _resolve_output_size(
    image_size: Tuple[int, int],
    height: Optional[int],
    width: Optional[int],
    max_side: int = MAX_AUTO_SIDE,
) -> Tuple[int, int]:
    """Return the ``(height, width)`` to generate at.

    * Both given: used as-is.
    * Only one given: the other is derived from the input image's aspect ratio.
    * Neither given: the image's own size, with its longer side capped at
      ``max_side`` and the shorter side scaled to keep the aspect ratio.
    """
    if height is not None and width is not None:
        return height, width

    image_width, image_height = image_size
    aspect_ratio = image_width / image_height

    if width is not None:
        return int(width / aspect_ratio), width
    if height is not None:
        return height, int(height * aspect_ratio)

    if image_width > image_height:
        width = min(image_width, max_side)
        return int(width / aspect_ratio), width
    height = min(image_height, max_side)
    return height, int(height * aspect_ratio)


class MuseVService:
    """Hold one loaded MuseV pipeline and turn an image plus prompt into MP4 bytes."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the MuseV pipeline and move it onto ``self.device``."""
        # Half precision is only dependable on GPU; many CPU kernels have no
        # float16 implementation, so use float32 on the CPU fallback.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Initialize the MuseV pipeline (adjust parameters as needed)
        self.pipeline = MuseVPipeline.from_pretrained(
            "TMElyralab/MuseV",
            torch_dtype=dtype,
            device=self.device,
        )
        self.pipeline.to(self.device)

    def generate_video(
        self,
        condition_image: Image.Image,
        prompt: str,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: int = 16,
        eye_blinks_factor: float = 1.8,
    ) -> bytes:
        """Generate a video conditioned on ``condition_image`` and return the MP4 bytes.

        Args:
            condition_image: Image used as condition, IP-adapter and reference image.
            prompt: Text prompt forwarded to the pipeline.
            height: Output height. ``None`` derives it (see ``_resolve_output_size``).
            width: Output width. ``None`` derives it (see ``_resolve_output_size``).
            num_frames: Currently unused (see note below).
            eye_blinks_factor: Forwarded to the pipeline config unchanged.
        """
        # NOTE(review): num_frames is accepted but never placed in the config,
        # so it has no effect today. Confirm the key MuseV expects before wiring it.
        # NOTE(review): the derived sizes are not rounded to any multiple (e.g. 8);
        # confirm whether the pipeline requires that.
        height, width = _resolve_output_size(condition_image.size, height, width)

        # Everything in temp_dir is deleted when the block exits, so the video
        # is read back into memory before leaving it.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save condition image
            condition_image_path = os.path.join(temp_dir, "condition.png")
            condition_image.save(condition_image_path)

            # Prepare configuration
            config = {
                "condition_images": condition_image_path,
                "prompt": prompt,
                "height": height,
                "width": width,
                "eye_blinks_factor": eye_blinks_factor,
                "img_length_ratio": 1.0,
                "ipadapter_image": condition_image_path,
                "refer_image": condition_image_path,
            }

            # Generate video
            output_path = os.path.join(temp_dir, "output.mp4")
            self.pipeline.generate(config, output_path)

            with open(output_path, "rb") as f:
                video_bytes = f.read()

        return video_bytes


# Initialize the service (loads the model once, at worker start-up)
service = MuseVService()


def _decode_image(image_data: str) -> Image.Image:
    """Decode a base64 image string, optionally a ``data:`` URI, into an RGB PIL image."""
    # Clients often send data URIs ("data:image/png;base64,..."). Non-validating
    # b64decode drops the ':' ';' ',' characters but keeps the header letters,
    # which corrupts the decoded bytes, so strip the header first.
    if image_data.startswith("data:") and "," in image_data:
        image_data = image_data.split(",", 1)[1]
    image = Image.open(BytesIO(base64.b64decode(image_data)))
    # Normalise to RGB so this path matches the Gradio path, which by default
    # hands the service RGB arrays. convert() also forces PIL's lazy load.
    return image.convert("RGB")


def handler(event):
    """RunPod handler function for API requests.

    Reads ``event["input"]``: a base64 ``image`` (required) plus optional
    ``prompt``, ``height``, ``width`` and ``eye_blinks_factor``.

    Returns:
        ``{"status": "success", "output": {"video": <base64 mp4>}}`` on success,
        or ``{"status": "error", "error": <message>}`` on any failure.
    """
    try:
        job_input = event["input"]

        image_data = job_input.get("image")
        if not image_data:
            raise ValueError("No image provided")

        video_bytes = service.generate_video(
            condition_image=_decode_image(image_data),
            prompt=job_input.get("prompt", ""),
            height=_optional_positive_int(job_input.get("height")),
            width=_optional_positive_int(job_input.get("width")),
            eye_blinks_factor=job_input.get("eye_blinks_factor", 1.8),
        )

        video_base64 = base64.b64encode(video_bytes).decode()
        return {
            "status": "success",
            "output": {
                "video": video_base64
            }
        }
    except Exception as e:
        # Top-level boundary: report the failure to the caller and keep the
        # traceback in the worker logs.
        logger.exception("MuseV job failed")
        return {
            "status": "error",
            "error": str(e)
        }


def create_gradio_interface():
    """
    Create Gradio interface
    """
    def generate_video_gradio(image, prompt, height, width, eye_blinks_factor):
        """Run one generation from the UI and return the path of the saved MP4."""
        if image is None:
            raise gr.Error("Please upload an input image.")
        try:
            video_bytes = service.generate_video(
                condition_image=Image.fromarray(image),
                prompt=prompt,
                # gr.Number yields floats and uses 0 for "not set".
                height=_optional_positive_int(height),
                width=_optional_positive_int(width),
                eye_blinks_factor=eye_blinks_factor,
            )
            # delete=False keeps the file after the handle closes so Gradio can serve it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_file.write(video_bytes)
            return temp_file.name
        except Exception as e:
            raise gr.Error(str(e)) from e

    # Only offer the example when the placeholder image actually exists;
    # a missing example file can make interface construction fail.
    examples = None
    if os.path.exists(EXAMPLE_IMAGE_PATH):
        examples = [
            [
                EXAMPLE_IMAGE_PATH,
                "(masterpiece, best quality, highres:1),(1person, solo:1),(eye blinks:1.8),(head wave:1.3)",
                512,
                512,
                1.8
            ]
        ]

    # Create the interface
    interface = gr.Interface(
        fn=generate_video_gradio,
        inputs=[
            gr.Image(label="Input Image", type="numpy"),
            gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
            gr.Number(label="Height (optional)", value=0),
            gr.Number(label="Width (optional)", value=0),
            gr.Slider(minimum=0.0, maximum=3.0, value=1.8, label="Eye Blinks Factor")
        ],
        outputs=gr.Video(label="Generated Video"),
        title="MuseV Video Generation",
        description="Generate videos from images using MuseV",
        examples=examples,
    )

    return interface


if __name__ == "__main__":
    # Start both the RunPod handler and Gradio interface
    interface = create_gradio_interface()

    # Gradio runs in a daemon thread so it does not keep the process alive
    # on its own; the RunPod loop below owns the main thread.
    threading.Thread(
        target=interface.launch,
        kwargs={
            "server_name": "0.0.0.0",
            "server_port": 3000,
            "share": False
        },
        daemon=True
    ).start()

    # Start the RunPod handler
    runpod.serverless.start({"handler": handler})