# Page status banner captured alongside this file: "Spaces: Runtime error".
import tempfile

import torch
import gradio as gr
from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download, snapshot_download
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fetch the Real-ESRGAN x4 weights and a snapshot of the RIFE repository into
# local folders. These are optional upscaling / interpolation assets; nothing
# else in the visible code reads them after the download.
hf_hub_download(
    repo_id="ai-forever/Real-ESRGAN",
    filename="RealESRGAN_x4.pth",
    local_dir="model_real_esran",
)
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

# Text-to-video pipeline (swap the model ID here to use a different checkpoint).
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(device)

# Rebuild the scheduler as a DPM scheduler from the pipeline's own scheduler
# config, overriding the timestep spacing to "trailing".
pipe.scheduler = CogVideoXDPMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Image-to-video transformer weights, loaded for optional image conditioning.
# The visible code never attaches this model to `pipe`.
i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b-I2V",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
def generate_video(prompt, style, duration, image):
    """
    Generate a Pokémon-themed video and return the path of the saved MP4 file.

    The user's prompt is "flavored" with the chosen style, the requested
    duration, and a reminder to include iconic Pokémon elements (Ash, Pikachu,
    Team Rocket) before being handed to the module-level ``pipe``.

    Args:
        prompt: Scene description typed by the user.
        style: Animation style label inserted into the prompt text.
        duration: Requested length in seconds. It is only mentioned in the
            prompt text; it is not forwarded to the pipeline as a frame count.
        image: Optional image from the UI. Accepted so the click handler's
            inputs line up, but not used: the text-to-video pipeline is called
            with the prompt only.

    Returns:
        Filesystem path (str) of the encoded ``.mp4`` file.
    """
    # Build a full prompt by combining user input with style and Pokémon-specific flavor.
    full_prompt = (
        f"{prompt}, in {style} style, lasting {duration} seconds. "
        "Include iconic Pokémon elements like Ash, Pikachu, and Team Rocket."
    )

    # Generate video (adjust inference parameters as needed).
    result = pipe(full_prompt, num_inference_steps=50, guidance_scale=7.5)

    # The CogVideoX pipeline output exposes the generated clips as `frames`
    # (one entry per prompt); it has no `videos` attribute.
    frames = result.frames[0]

    # Encode the frames to a file and return its path instead of the raw frames.
    # delete=False keeps the file on disk after the handle is closed so it can
    # still be read once this function has returned.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
        output_path = tmp_file.name
    export_to_video(frames, output_path, fps=8)  # 8 fps: CogVideoX's frame rate per its model card
    return output_path
# Assemble the Gradio interface, wire the button to generate_video, and serve it.
# NOTE(review): the original indentation was lost, so which widgets belong
# inside gr.Row() is an assumption (prompt + style) — confirm the intended layout.
with gr.Blocks() as demo:
    # Page title and short usage description.
    gr.Markdown("# 🎥 PokeVidGen AI")
    gr.Markdown(
        "Generate Pokémon anime shorts using CogVideoX-5b! Enter your scene prompt, "
        "choose an animation style, set the duration, and optionally upload an image."
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Pokémon Scene",
            placeholder="Ash battles Team Rocket with Pikachu's Thunderbolt",
        )
        style_input = gr.Dropdown(
            choices=["Anime Classic", "Modern 3D", "Cartoon"],
            label="Animation Style",
            value="Anime Classic",
        )

    duration_input = gr.Slider(
        minimum=1,
        maximum=10,
        step=1,
        label="Duration (seconds)",
        value=5,
    )
    # type="filepath" hands the callback a path string rather than image data.
    image_input = gr.Image(label="Optional Image", type="filepath")

    generate_button = gr.Button("Generate Video")
    video_output = gr.Video(label="Generated Video")

    # Input order must match generate_video's positional parameters.
    generate_button.click(
        fn=generate_video,
        inputs=[prompt_input, style_input, duration_input, image_input],
        outputs=video_output,
    )

demo.launch()