File size: 5,056 Bytes
e3dcab5
 
 
84a1d8e
e3dcab5
 
 
 
 
84a1d8e
 
 
 
e3dcab5
 
 
84a1d8e
e3dcab5
 
 
 
 
 
 
84a1d8e
e3dcab5
 
 
84a1d8e
 
 
e3dcab5
 
 
 
 
 
 
 
84a1d8e
 
e3dcab5
 
 
 
84a1d8e
 
e3dcab5
 
 
 
 
 
 
84a1d8e
 
 
 
 
 
 
e3dcab5
 
 
 
 
 
 
 
 
84a1d8e
 
 
e3dcab5
 
 
84a1d8e
 
e3dcab5
84a1d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3dcab5
 
 
 
84a1d8e
 
 
 
 
e3dcab5
 
 
 
84a1d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from huggingface_hub import InferenceClient
import base64
import os
import re
from pathlib import Path
import time

def save_video(base64_video: str, output_path: str) -> None:
    """Decode a base64-encoded video and write it to *output_path*.

    Accepts either a bare base64 string or a data URI (e.g.
    ``data:video/mp4;base64,<payload>``); any ``data:<mime>;base64,``
    header is stripped before decoding.

    Args:
        base64_video: Base64 payload, optionally wrapped in a data URI.
        output_path: Destination file path for the decoded bytes.
    """
    # Strip a data-URI header if present. maxsplit=1 guards against the
    # (unlikely) case of the literal "base64," appearing inside the payload.
    if base64_video.startswith('data:'):
        base64_video = base64_video.split('base64,', 1)[-1]

    video_bytes = base64.b64decode(base64_video)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    print(f"Video saved to: {output_path}")

def generate_video(
    prompt: str,
    endpoint_url: str,
    token: str = None,
    resolution: str = "1280x720",
    video_length: int = 129,
    num_inference_steps: int = 30,
    seed: int = -1,
    guidance_scale: float = 1.0,
    flow_shift: float = 7.0,
    embedded_guidance_scale: float = 6.0,
    enable_riflex: bool = True,
    tea_cache: float = 0.0
) -> str:
    """Generate a video using the custom inference endpoint.

    Args:
        prompt: Text prompt describing the video
        endpoint_url: Full URL to the inference endpoint
        token: HuggingFace API token for authentication
        resolution: Video resolution (default: "1280x720")
        video_length: Number of frames (default: 129)
        num_inference_steps: Number of inference steps (default: 30)
        seed: Random seed, -1 for random (default: -1)
        guidance_scale: Guidance scale value (default: 1.0)
        flow_shift: Flow shift value (default: 7.0)
        embedded_guidance_scale: Embedded guidance scale (default: 6.0)
        enable_riflex: Enable RIFLEx positional embedding for long videos (default: True)
        tea_cache: TeaCache acceleration threshold, 0.0 to disable, 0.1 for 1.6x speedup, 0.15 for 2.1x speedup (default: 0.0)

    Returns:
        Path to the saved video file

    Raises:
        Exception: Re-raises any error from the endpoint request after logging it.
    """
    # Initialize client
    client = InferenceClient(model=endpoint_url, token=token)

    print(f"Generating video with prompt: \"{prompt}\"")
    print(f"Resolution: {resolution}, Length: {video_length} frames")
    print(f"Steps: {num_inference_steps}, Seed: {'random' if seed == -1 else seed}")

    # Sanitize filename from prompt: keep word chars/spaces/hyphens,
    # truncate to 50 chars, and turn spaces into underscores.
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip().replace(' ', '_')

    # Prepare payload
    payload = {
        "inputs": prompt,
        "resolution": resolution,
        "video_length": video_length,
        "num_inference_steps": num_inference_steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "flow_shift": flow_shift,
        "embedded_guidance_scale": embedded_guidance_scale,
        "enable_riflex": enable_riflex,
        "tea_cache": tea_cache
    }

    # Make request
    start_time = time.time()
    print("Sending request to endpoint...")

    try:
        response = client.post(json=payload)

        # NOTE(review): this treats `response` as a requests-style Response
        # object (.headers/.json()/.text). Some huggingface_hub versions
        # return raw bytes from InferenceClient.post — confirm against the
        # installed version.
        if response.headers.get('content-type') == 'application/json':
            result = response.json()
            # Endpoint may return {"video_base64": ...} or the payload directly.
            video_data = result.get("video_base64", result)
        else:
            # The response might be directly the data URI
            video_data = response.text

        generation_time = time.time() - start_time
        print(f"Video generated in {generation_time:.2f} seconds")

        # Save video under a prompt-derived, timestamped filename.
        timestamp = int(time.time())
        output_path = f"{safe_prompt}_{timestamp}.mp4"

        if isinstance(video_data, str):
            # save_video handles both bare base64 and data-URI strings,
            # so a single branch covers both cases.
            save_video(video_data, output_path)
        else:
            # Assume it's a dictionary with a base64 key
            save_video(video_data.get("video_base64", ""), output_path)

        return output_path

    except Exception as e:
        # Log the failure for the CLI user, then propagate to the caller.
        print(f"Error generating video: {e}")
        raise

if __name__ == "__main__":
    # Credentials and endpoint come from the environment.
    hf_api_token = os.environ.get('HF_API_TOKEN', '')
    endpoint_url = os.environ.get('ENDPOINT_URL', '')

    if not endpoint_url:
        print("Please set the ENDPOINT_URL environment variable")
        exit(1)

    # Settings for a short (~4 s at 24 fps) HD clip.
    generation_settings = {
        "prompt": "A cat walks on the grass, realistic style.",
        # Video configuration
        "resolution": "1280x720",       # Standard HD resolution
        "video_length": 97,             # About 4 seconds at 24fps
        # Generation parameters
        "num_inference_steps": 22,      # Default for standard model
        "seed": -1,                     # Random seed
        # Advanced parameters
        "guidance_scale": 1.0,
        "embedded_guidance_scale": 6.0,
        "flow_shift": 7.0,
        # Optimizations
        "enable_riflex": True,          # Better for videos longer than 4 seconds
        "tea_cache": 0.0,               # Set to 0.1 or 0.15 for faster generation with slight quality loss
    }

    video_path = generate_video(
        endpoint_url=endpoint_url,
        token=hf_api_token,
        **generation_settings,
    )