File size: 5,056 Bytes
e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e e3dcab5 84a1d8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import base64
import os
import re
import time
from pathlib import Path
from typing import Optional

from huggingface_hub import InferenceClient
def save_video(base64_video: str, output_path: str):
    """Decode a base64-encoded video and write it to *output_path*.

    Args:
        base64_video: Raw base64 data, optionally wrapped in a data URI
            (e.g. ``data:video/mp4;base64,...``).
        output_path: Destination file path for the decoded bytes.
    """
    # Strip any data-URI header. The original check matched only the exact
    # literal "data:video/mp4;base64,", so other MIME types (webm, etc.)
    # fell through and b64decode failed on the "data:" prefix.
    if base64_video.startswith('data:') and 'base64,' in base64_video:
        base64_video = base64_video.split('base64,', 1)[1]
    video_bytes = base64.b64decode(base64_video)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    print(f"Video saved to: {output_path}")
def generate_video(
    prompt: str,
    endpoint_url: str,
    token: Optional[str] = None,
    resolution: str = "1280x720",
    video_length: int = 129,
    num_inference_steps: int = 30,
    seed: int = -1,
    guidance_scale: float = 1.0,
    flow_shift: float = 7.0,
    embedded_guidance_scale: float = 6.0,
    enable_riflex: bool = True,
    tea_cache: float = 0.0
) -> str:
    """Generate a video using the custom inference endpoint.

    Args:
        prompt: Text prompt describing the video.
        endpoint_url: Full URL to the inference endpoint.
        token: HuggingFace API token for authentication (None for no auth).
        resolution: Video resolution (default: "1280x720").
        video_length: Number of frames (default: 129).
        num_inference_steps: Number of inference steps (default: 30).
        seed: Random seed, -1 for random (default: -1).
        guidance_scale: Guidance scale value (default: 1.0).
        flow_shift: Flow shift value (default: 7.0).
        embedded_guidance_scale: Embedded guidance scale (default: 6.0).
        enable_riflex: Enable RIFLEx positional embedding for long videos
            (default: True).
        tea_cache: TeaCache acceleration threshold, 0.0 to disable, 0.1 for
            1.6x speedup, 0.15 for 2.1x speedup (default: 0.0).

    Returns:
        Path to the saved video file.

    Raises:
        Exception: Re-raised from the underlying HTTP call on any failure
            (logged before re-raising).
    """
    # Initialize client pointed at the custom endpoint.
    client = InferenceClient(model=endpoint_url, token=token)
    print(f"Generating video with prompt: \"{prompt}\"")
    print(f"Resolution: {resolution}, Length: {video_length} frames")
    print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")

    # Sanitize filename from prompt: keep word chars/spaces/hyphens,
    # truncate to 50 chars, spaces -> underscores.
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_')

    # Prepare payload expected by the endpoint's handler.
    payload = {
        "inputs": prompt,
        "resolution": resolution,
        "video_length": video_length,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "flow_shift": flow_shift,
        "embedded_guidance_scale": embedded_guidance_scale,
        "enable_riflex": enable_riflex,
        "tea_cache": tea_cache
    }

    # Make request (timed for the log line below).
    start_time = time.time()
    print("Sending request to endpoint...")
    try:
        # NOTE(review): this assumes client.post returns a requests-style
        # Response (.headers / .json() / .text). Recent huggingface_hub
        # versions return raw bytes instead — confirm against the pinned
        # huggingface_hub version.
        response = client.post(json=payload)

        # JSON responses carry the video under "video_base64"; anything else
        # is treated as the base64/data-URI string itself.
        if response.headers.get('content-type') == 'application/json':
            result = response.json()
            video_data = result.get("video_base64", result)
        else:
            video_data = response.text

        generation_time = time.time() - start_time
        print(f"Video generated in {generation_time:.2f} seconds")

        # Save video under a prompt-derived, timestamped filename.
        timestamp = int(time.time())
        output_path = f"{safe_prompt}_{timestamp}.mp4"

        if isinstance(video_data, str):
            # save_video strips a data-URI prefix itself, so the plain-string
            # and data-URI cases collapse into one branch (the original had
            # two byte-identical branches here).
            save_video(video_data, output_path)
        else:
            # Assume it's a dictionary with a base64 key.
            save_video(video_data.get("video_base64", ""), output_path)
        return output_path
    except Exception as e:
        print(f"Error generating video: {e}")
        raise
if __name__ == "__main__":
    # Credentials and target come from the environment; the token may be
    # empty (unauthenticated), but the endpoint URL is mandatory.
    api_token = os.environ.get('HF_API_TOKEN', '')
    target_url = os.environ.get('ENDPOINT_URL', '')

    if not target_url:
        print("Please set the ENDPOINT_URL environment variable")
        exit(1)

    video_path = generate_video(
        prompt="A cat walks on the grass, realistic style.",
        endpoint_url=target_url,
        token=api_token,
        # Video configuration
        resolution="1280x720",        # Standard HD resolution
        video_length=97,              # About 4 seconds at 24fps
        # Generation parameters
        num_inference_steps=22,       # Default for standard model
        seed=-1,                      # Random seed
        # Advanced parameters
        guidance_scale=1.0,
        embedded_guidance_scale=6.0,
        flow_shift=7.0,
        # Optimizations
        enable_riflex=True,           # Better for videos longer than 4 seconds
        tea_cache=0.0,                # Set to 0.1 or 0.15 for faster generation with slight quality loss
    )