from huggingface_hub import InferenceClient
import base64
import os
import re
import sys
import time


def save_video(base64_video: str, output_path: str):
    """Decode a base64-encoded MP4 video and write it to *output_path*.

    Accepts either a raw base64 string or a ``data:video/mp4;base64,...``
    data URI (the prefix is stripped before decoding).
    """
    if base64_video.startswith('data:video/mp4;base64,'):
        base64_video = base64_video.split('base64,')[1]

    video_bytes = base64.b64decode(base64_video)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    print(f"Video saved to: {output_path}")


def generate_video(
    prompt: str,
    endpoint_url: str,
    token: str = None,
    resolution: str = "1280x720",
    video_length: int = 129,
    num_inference_steps: int = 30,
    seed: int = -1,
    guidance_scale: float = 1.0,
    flow_shift: float = 7.0,
    embedded_guidance_scale: float = 6.0,
    enable_riflex: bool = True,
    tea_cache: float = 0.0
) -> str:
    """Generate a video using the custom inference endpoint.

    Args:
        prompt: Text prompt describing the video
        endpoint_url: Full URL to the inference endpoint
        token: HuggingFace API token for authentication
        resolution: Video resolution (default: "1280x720")
        video_length: Number of frames (default: 129)
        num_inference_steps: Number of inference steps (default: 30)
        seed: Random seed, -1 for random (default: -1)
        guidance_scale: Guidance scale value (default: 1.0)
        flow_shift: Flow shift value (default: 7.0)
        embedded_guidance_scale: Embedded guidance scale (default: 6.0)
        enable_riflex: Enable RIFLEx positional embedding for long videos
            (default: True)
        tea_cache: TeaCache acceleration threshold, 0.0 to disable,
            0.1 for 1.6x speedup, 0.15 for 2.1x speedup (default: 0.0)

    Returns:
        Path to the saved video file

    Raises:
        Exception: re-raised from the underlying client call on any failure.
    """
    client = InferenceClient(model=endpoint_url, token=token)

    print(f"Generating video with prompt: \"{prompt}\"")
    print(f"Resolution: {resolution}, Length: {video_length} frames")
    print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")

    # Build a filesystem-safe filename stem from the prompt; fall back to a
    # fixed name if the prompt contains no usable word characters at all.
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_') or "video"

    payload = {
        "inputs": prompt,
        "resolution": resolution,
        "video_length": video_length,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "flow_shift": flow_shift,
        "embedded_guidance_scale": embedded_guidance_scale,
        "enable_riflex": enable_riflex,
        "tea_cache": tea_cache
    }

    start_time = time.time()
    print("Sending request to endpoint...")

    try:
        # NOTE(review): this assumes `client.post` returns a requests-style
        # response object exposing `.headers`/`.json()`/`.text`. In some
        # huggingface_hub versions `InferenceClient.post` returns raw bytes
        # instead — confirm against the pinned library version.
        response = client.post(json=payload)

        # The endpoint may answer with JSON ({"video_base64": ...}) or with
        # the data URI / base64 string directly in the body.
        if response.headers.get('content-type') == 'application/json':
            result = response.json()
            video_data = result.get("video_base64", result)
        else:
            video_data = response.text

        generation_time = time.time() - start_time
        print(f"Video generated in {generation_time:.2f} seconds")

        timestamp = int(time.time())
        output_path = f"{safe_prompt}_{timestamp}.mp4"

        # save_video handles both raw base64 and data-URI strings, so a
        # single string branch suffices (the original had two identical ones).
        if isinstance(video_data, str):
            save_video(video_data, output_path)
        else:
            # Assume a dictionary carrying the base64 payload.
            save_video(video_data.get("video_base64", ""), output_path)

        return output_path

    except Exception as e:
        print(f"Error generating video: {e}")
        raise


if __name__ == "__main__":
    hf_api_token = os.environ.get('HF_API_TOKEN', '')
    endpoint_url = os.environ.get('ENDPOINT_URL', '')

    if not endpoint_url:
        print("Please set the ENDPOINT_URL environment variable")
        sys.exit(1)

    video_path = generate_video(
        endpoint_url=endpoint_url,
        token=hf_api_token,
        prompt="A cat walks on the grass, realistic style.",
        # Video configuration
        resolution="1280x720",      # Standard HD resolution
        video_length=97,            # About 4 seconds at 24fps
        # Generation parameters
        num_inference_steps=22,     # Default for standard model
        seed=-1,                    # Random seed
        # Advanced parameters
        guidance_scale=1.0,
        embedded_guidance_scale=6.0,
        flow_shift=7.0,
        # Optimizations
        enable_riflex=True,         # Better for videos longer than 4 seconds
        tea_cache=0.0               # Set to 0.1 or 0.15 for faster generation with slight quality loss
    )