File size: 3,422 Bytes
2af7c18
eaa76b7
2af7c18
 
 
 
 
 
b766640
2af7c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa76b7
 
2af7c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8097377
2af7c18
882b96e
2af7c18
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import numpy as np
import torch
import spaces
from diffusers import FluxPipeline, FluxTransformer2DModel
from PIL import Image
from diffusers.utils import export_to_gif
import uuid
import random

# Run on the GPU in bfloat16 when CUDA is available; otherwise fall back to
# the CPU in float32.
device, torch_dtype = (
    ("cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)

def split_image(input_image, num_splits=4, frame_size=256):
    """Cut a horizontal strip of frames into individual square images.

    Frames are taken left to right along the top edge of ``input_image``.
    Frame ``i`` is the crop box
    ``(i * frame_size, 0, (i + 1) * frame_size, frame_size)``.

    Args:
        input_image: Image object exposing a PIL-style ``crop(box)`` method.
        num_splits: Number of frames to cut from the strip.
        frame_size: Side length, in pixels, of each square frame. Defaults to
            256, the value that used to be hard-coded.

    Returns:
        A list of ``num_splits`` cropped images in left-to-right order.
    """
    return [
        input_image.crop(
            (index * frame_size, 0, (index + 1) * frame_size, frame_size)
        )
        for index in range(num_splits)
    ]

# Load the FLUX.1-schnell pipeline with the dtype chosen above and move it to
# the selected device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch_dtype
).to(device)

# Upper bound for the seed slider and for randomly drawn seeds: the largest
# signed 32-bit integer.
MAX_SEED = int(np.iinfo("int32").max)

@spaces.GPU
def infer(prompt, seed, randomize_seed, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Generate a 4-frame animated GIF for ``prompt`` and return its path.

    The pipeline is asked for a single wide image containing four frames side
    by side; that strip is then cut into frames and written out as a GIF.

    Args:
        prompt: User description of the animation.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker. Not used directly in the body;
            presumably present so Gradio mirrors tqdm progress in the UI.

    Returns:
        A ``(gif_path, seed)`` tuple: the file name of the written GIF and the
        seed that was actually used.
    """
    num_frames = 4
    frame_size = 256

    prompt_template = f"A side by side 4 frame image showing consecutive stills from a looped gif moving from left to right. The gif is {prompt}"
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; manual_seed requires an int.
    seed = int(seed)

    image = pipe(
        # Send the templated prompt so the model is asked for a 4-frame strip.
        prompt=prompt_template,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=1,
        generator=torch.Generator(device).manual_seed(seed),
        # One row of `num_frames` square frames, matching what split_image crops.
        height=frame_size,
        width=frame_size * num_frames,
    ).images[0]
    gif_name = f"{uuid.uuid4().hex}-flux.gif"
    export_to_gif(split_image(image, num_frames), gif_name, fps=4)
    return gif_name, seed

# Sample prompts offered below the prompt box via gr.Examples.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Stylesheet passed to gr.Blocks: centers the element with id "col-container"
# and caps its width at 640px.
css="""
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Build the UI: a prompt box with a Run button, the resulting image, and an
# accordion with seed / step controls. Running wires everything to `infer`.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # FLUX.1 Schnell Animations
        Generate gifs with FLUX.1 [schnell]
        """)

        with gr.Row():

            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        # Displays the GIF file whose path `infer` returns.
        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=12,
                step=1,
                value=4,
            )

        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )

    # Clicking Run or submitting the prompt box both trigger generation; the
    # seed actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, num_inference_steps],
        outputs=[result, seed],
    )

demo.queue().launch()