FLUX-Animations / app.py
multimodalart's picture
Create app.py
2af7c18 verified
raw
history blame
3.4 kB
import random
import uuid

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.utils import export_to_gif
from PIL import Image
# Run on the GPU when one is visible, in bfloat16; otherwise fall back to the
# CPU in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Upper bound for user-selectable seeds (largest signed 32-bit integer). It is
# used by infer() when randomizing the seed and as the Seed slider's maximum.
MAX_SEED = 2**31 - 1
def split_image(input_image, num_splits=4, *, frame_width=256, frame_height=256):
    """Cut a horizontal strip of frames into individual frame images.

    Frames are taken left to right, starting at the top-left corner of
    ``input_image``. Each frame is ``frame_width`` x ``frame_height`` pixels.

    Args:
        input_image: Image exposing PIL's ``crop(box)`` API, where ``box`` is
            ``(left, upper, right, lower)``.
        num_splits: Number of frames to extract.
        frame_width: Width of each frame in pixels (keyword-only).
        frame_height: Height of each frame in pixels (keyword-only).

    Returns:
        A list of ``num_splits`` cropped images, in left-to-right order.
    """
    return [
        input_image.crop((i * frame_width, 0, (i + 1) * frame_width, frame_height))
        for i in range(num_splits)
    ]
# Load FLUX.1 [schnell] once at import time in the selected dtype and place it
# on the selected device. DiffusionPipeline.to() returns the pipeline itself,
# so the call can be chained.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch_dtype,
).to(device)
@spaces.GPU
def infer(prompt, seed, randomize_seed, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Generate a looping GIF for *prompt* with FLUX.1 [schnell].

    The model is asked for one wide image containing four consecutive frames
    side by side. That image is cut into 256x256 frames, and the frames are
    written out as a GIF.

    Args:
        prompt: User description of the animation.
        seed: Seed for the generator; ignored when ``randomize_seed`` is set.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets the UI
            follow the pipeline's tqdm progress bar.

    Returns:
        ``(gif_path, seed)``: the relative path of the GIF that was written,
        and the seed actually used, so the UI can show it.
    """
    num_frames = 4
    # split_image() cuts 256x256 tiles, so the canvas must be one tile tall
    # and num_frames tiles wide.
    frame_size = 256
    prompt_template = f"A side by side 4 frame image showing consecutive stills from a looped gif moving from left to right. The gif is {prompt}"
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Cast defensively in case the slider delivers a float; manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator(device).manual_seed(seed)
    image = pipe(
        prompt=prompt_template,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=1,
        generator=generator,
        height=frame_size,
        width=frame_size * num_frames,
    ).images[0]
    # A unique file name keeps concurrent requests from overwriting each other.
    gif_name = f"{uuid.uuid4().hex}-flux.gif"
    export_to_gif(split_image(image, num_frames), gif_name, fps=4)
    return gif_name, seed
# Sample prompts shown under the prompt box through gr.Examples.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Stylesheet handed to gr.Blocks. It centres the #col-container column
# horizontally and caps its width at 640px.
css = "\n".join([
    "",
    "#col-container {",
    "margin: 0 auto;",
    "max-width: 640px;",
    "}",
    "",
])
# Gradio UI: a prompt box with a Run button, the resulting GIF, and an
# "Advanced Settings" accordion for seed and step count.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "# FLUX.1 Schnell Animations\n"
            "Generate looping GIFs with FLUX.1 [schnell]: the model draws four "
            "consecutive frames side by side in a single image, which is then "
            "split into frames and stitched into a GIF."
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=12,
                step=1,
                value=4,
            )
        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )
    # Clicking Run or pressing Enter in the prompt box both call infer().
    # The inputs are listed in the order of infer()'s positional parameters;
    # the seed slider is also an output so that it shows the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, num_inference_steps],
        outputs=[result, seed],
    )
demo.queue().launch()