import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch

# Set the device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"

# The ByteDance/AnimateDiff-Lightning repository (hosts distilled motion-module checkpoints)
model_repo_id = "ByteDance/AnimateDiff-Lightning"

# Set the torch dtype based on available hardware
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# Attempt a generic pipeline load from the repository. Note that AnimateDiff-Lightning
# ships motion-module weights rather than a full pipeline, so the loading path documented
# on the model card goes through AnimateDiffPipeline with a MotionAdapter (sketched below).
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
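
# NOTE: the generic DiffusionPipeline load above assumes the repository ships a complete
# diffusers pipeline. AnimateDiff-Lightning actually publishes distilled motion-module
# checkpoints, and its model card loads them into an AnimateDiffPipeline on top of a
# separate Stable Diffusion 1.5 base model. The helper below is a hedged sketch of that
# documented path (the step count and base model are assumptions taken from the model
# card example); it is provided for reference and is not called by default.
def load_animatediff_lightning_pipeline(step=4, base_model="emilianJR/epiCRealism"):
    from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    # Checkpoints follow the repo's naming scheme for the 1-, 2-, 4- and 8-step variants
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    adapter = MotionAdapter().to(device, torch_dtype)
    adapter.load_state_dict(load_file(hf_hub_download(model_repo_id, ckpt), device=device))

    lightning_pipe = AnimateDiffPipeline.from_pretrained(
        base_model, motion_adapter=adapter, torch_dtype=torch_dtype
    ).to(device)
    # Lightning checkpoints expect trailing timestep spacing and a linear beta schedule
    lightning_pipe.scheduler = EulerDiscreteScheduler.from_config(
        lightning_pipe.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="linear",
    )
    return lightning_pipe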

# Maximum values for seed and image size
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Define the inference function
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):

    # Randomize seed if the checkbox is selected
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
        
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate the animation using the pipeline
    animation = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]  # Generic image pipelines expose `.images`; AnimateDiff pipelines return `.frames` instead
    
    return animation, seed
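
# If the AnimateDiffPipeline path sketched above is used instead, the output carries a
# list of frames (`output.frames[0]`) rather than `.images`. The helper below is a hedged
# sketch of turning those frames into a GIF (the function name, default step count and
# output file name are assumptions, not part of the original app).
def infer_lightning(lightning_pipe, prompt, num_steps=4, gif_path="animation.gif"):
    from diffusers.utils import export_to_gif

    # Lightning checkpoints are distilled for guidance_scale=1.0 and very few steps
    output = lightning_pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_steps)
    return export_to_gif(output.frames[0], gif_path)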

# Sample prompts for the UI
examples = [
    "A cat playing with a ball in a garden",
    "A dancing astronaut in space",
    "A flying dragon in the sky at sunset",
]

# Define CSS for styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Build the Gradio UI
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AnimateDiff-Lightning Text-to-Animation")
        
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Generated Animation", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,
                )

        # Example prompts for user selection
        gr.Examples(
            examples=examples,
            inputs=[prompt]
        )

    # Wire the run button and the prompt's Enter key to the inference function;
    # Gradio exposes this event as an API endpoint automatically.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
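
# The wired event above is also reachable over Gradio's auto-generated API. A hedged
# sketch of calling it with gradio_client (the local URL and the "/infer" endpoint name,
# Gradio's default for a function called `infer`, are assumptions):
#
#     from gradio_client import Client
#     client = Client("http://127.0.0.1:7860/")
#     animation, used_seed = client.predict(
#         "A cat playing with a ball in a garden",  # prompt
#         "",     # negative prompt
#         0,      # seed
#         True,   # randomize seed
#         512,    # width
#         512,    # height
#         7.5,    # guidance scale
#         30,     # inference steps
#         api_name="/infer",
#     )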