# Cosmos-Predict2 / app.py
# multimodalart's picture
# faster safety checking
# b4b0028 verified
import subprocess
# Install flash-attn at startup, before the model libraries are imported.
#
# `env=` REPLACES the child process environment rather than extending it, so
# the flag is layered on top of the inherited environment; passing only the
# single key would strip PATH, HOME, proxy settings, etc. from the pip process.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells the flash-attn
# build to skip compiling CUDA kernels — confirm against the flash-attn docs.
# The return code is intentionally not checked (best-effort install, as before).
import os

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
import gradio as gr
import spaces
import torch
from diffusers import Cosmos2TextToImagePipeline, EDMEulerScheduler
from transformers import AutoModelForCausalLM, SiglipProcessor
import random
# Add flash_attention_2 to the safeguard model
def patch_from_pretrained(cls):
    """Patch ``cls.from_pretrained`` to default to flash-attention 2 in bfloat16.

    The replacement fills in ``attn_implementation="flash_attention_2"`` and
    ``torch_dtype=torch.bfloat16`` only when the caller did not supply them,
    then delegates to the original loader and returns its result.

    The wrapper is installed as a classmethod so that it keeps dispatching on
    the class it is actually called on (subclasses included) and can be called
    on an instance without the instance leaking into the positional arguments.

    Args:
        cls: Class exposing a ``from_pretrained`` attribute. Patched in place.
    """
    import functools

    original = cls.from_pretrained
    # A classmethod looked up on its class is a bound method; ``__func__`` is
    # the underlying function, which can be re-bound to whichever class the
    # wrapper is invoked on. Anything else (staticmethod, plain function) has
    # no ``__func__`` and is called exactly as it was before.
    raw = getattr(original, "__func__", None)

    @functools.wraps(original if raw is None else raw)
    def new_from_pretrained(klass, *args, **kwargs):
        # setdefault: explicit caller choices always win over these defaults.
        kwargs.setdefault("attn_implementation", "flash_attention_2")
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        if raw is None:
            return original(*args, **kwargs)
        return raw(klass, *args, **kwargs)

    cls.from_pretrained = classmethod(new_from_pretrained)
# Apply the from_pretrained default-kwargs patch to AutoModelForCausalLM.
patch_from_pretrained(AutoModelForCausalLM)
# Add a `use_fast` to the safeguard image processor
def patch_processor_fast(cls):
    """Patch ``cls.from_pretrained`` so ``use_fast=True`` is the default.

    The replacement sets ``use_fast=True`` only when the caller did not pass
    ``use_fast`` explicitly, then delegates to the original loader and returns
    its result.

    The wrapper is installed as a classmethod so that it keeps dispatching on
    the class it is actually called on (subclasses included) and can be called
    on an instance without the instance leaking into the positional arguments.

    Args:
        cls: Class exposing a ``from_pretrained`` attribute. Patched in place.
    """
    import functools

    original = cls.from_pretrained
    # A classmethod looked up on its class is a bound method; ``__func__`` is
    # the underlying function, which can be re-bound to whichever class the
    # wrapper is invoked on. Anything else (staticmethod, plain function) has
    # no ``__func__`` and is called exactly as it was before.
    raw = getattr(original, "__func__", None)

    @functools.wraps(original if raw is None else raw)
    def new_from_pretrained(klass, *args, **kwargs):
        # setdefault: an explicit use_fast from the caller is left untouched.
        kwargs.setdefault("use_fast", True)
        if raw is None:
            return original(*args, **kwargs)
        return raw(klass, *args, **kwargs)

    cls.from_pretrained = classmethod(new_from_pretrained)
# Apply the use_fast default patch to SiglipProcessor.
patch_processor_fast(SiglipProcessor)
# Repository id of the 14B text-to-image checkpoint.
model_14b_id = "nvidia/Cosmos-Predict2-14B-Text2Image"

# Build the pipeline once at import time with bfloat16 weights, then move it
# to the CUDA device so generate_image() can reuse it on every request.
pipe_14b = Cosmos2TextToImagePipeline.from_pretrained(model_14b_id, torch_dtype=torch.bfloat16)
pipe_14b.to("cuda")
@spaces.GPU(duration=140)
def generate_image(
    prompt,
    negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",
    seed=42,
    randomize_seed=False,
    model_choice="14B",
    progress=gr.Progress(track_tqdm=True),
):
    """Run the module-level 14B pipeline on a prompt and return the first image.

    Args:
        prompt: Text forwarded to the pipeline as ``prompt``.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        seed: Seed used when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, ``seed`` is ignored and a value is drawn
            with ``random.randint(0, 1000000)``.
        model_choice: Accepted but not read by this function; only ``pipe_14b``
            is ever invoked.
        progress: Not referenced in the body.
            NOTE(review): presumably lets Gradio mirror tqdm progress — confirm.

    Returns:
        A ``(image, seed)`` tuple: the first entry of the pipeline output's
        ``images`` and the seed that was actually used.
    """
    # Draw a fresh seed only on request; otherwise honour the caller's value.
    actual_seed = random.randint(0, 1000000) if randomize_seed else seed

    # Default-constructed generator, seeded so a given seed is repeatable.
    rng = torch.Generator()
    rng.manual_seed(actual_seed)

    result = pipe_14b(prompt=prompt, negative_prompt=negative_prompt, generator=rng)
    first_image = result.images[0]
    return first_image, actual_seed
# Sample prompts offered in the UI; each entry is a single prompt string that
# is passed to gr.Examples further down.
example_prompts = [
    "A well-worn broom sweeps across a dusty wooden floor, its bristles gathering crumbs and flecks of debris in swift, rhythmic strokes. Dust motes dance in the sunbeams filtering through the window, glowing momentarily before settling. The quiet swish of straw brushing wood is interrupted only by the occasional creak of old floorboards. With each pass, the floor grows cleaner, restoring a sense of quiet order to the humble room.",
    "A laundry machine whirs to life, tumbling colorful clothes behind the foggy glass door. Suds begin to form in a frothy dance, clinging to fabric as the drum spins. The gentle thud of shifting clothes creates a steady rhythm, like a heartbeat of the home. Outside the machine, a quiet calm fills the room, anticipation building for the softness and warmth of freshly laundered garments.",
    "A robotic arm tightens a bolt beneath the hood of a car, its tool head rotating with practiced torque. The metal-on-metal sound clicks into place, and the arm pauses briefly before retracting with a soft hydraulic hiss. Overhead lights reflect off the glossy vehicle surface, while scattered tools and screens blink in the background—a garage scene reimagined through the lens of precision engineering.",
    "A nighttime city bus terminal gradually shifts from stillness to subtle movement. At first, multiple double-decker buses are parked under the glow of overhead lights, with a central bus labeled '87D' facing forward and stationary. As the video progresses, the bus in the middle moves ahead slowly, its headlights brightening the surrounding area and casting reflections onto adjacent vehicles. The motion creates space in the lineup, signaling activity within the otherwise quiet station. It then comes to a smooth stop, resuming its position in line. Overhead signage in Chinese characters remains illuminated, enhancing the vibrant, urban night scene.",
    "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor.",
    "In the later moments of the video, the female worker in the front, dressed in a white coat and hairnet, performs a repetitive yet precise task. She scoops golden granular material from a wide jar and steadily pours it into the next empty glass bottle on the conveyor belt. Her hand moves with practiced control as she aligns the scoop over each container, ensuring an even fill. The sequence highlights her focused attention and consistent motion, capturing the shift from preparation to active material handling as the production line advances bottle by bottle.",
    "A wide-angle shot captures a sunny suburban street intersection, where the bright sunlight casts sharp shadows across the road. The scene is framed by a row of houses with beige and brown roofs, and lush green lawns. Autumn-colored trees add vibrant red and orange hues to the landscape. Overhead power lines stretch across the sky, and a fire hydrant is visible on the right side of the frame near the curb. A silver sedan is parked on the driveway of a house on the left, while a silver SUV is parked on the street in front of the house at the center of the camera view. The ego vehicle waits to turn right at the t-intersection, yielding to two other vehicles traveling in opposite directions. A black car enters the frame from the right, driving across the intersection and continuing straight ahead. The car's movement is smooth and steady, and it exits the frame to the left. The final frame shows the intersection with a vehicle moving from the left to the right side, the silver sedan and the SUV still parked in their initial positions, and the black car having moved out of view."
]
# Define the Gradio Blocks interface
# NOTE(review): the nesting below (which components sit in which Row/Column,
# and that Examples and the click handler live inside the Blocks context) is
# reconstructed from statement order — confirm against the deployed layout.
with gr.Blocks() as demo:
    # Page header. The string body is kept flush-left so its contents are
    # exactly the original text.
    gr.Markdown(
        """
# Cosmos-Predict2 14B Text2Image
[[Model]](https://huggingface.co/nvidia/Cosmos-Predict2-14B-Text2Image), [[Code]](https://github.com/nvidia-cosmos/cosmos-predict2)
"""
    )
    with gr.Row():
        # Input controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=5,
                value="A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess.",
                placeholder="Enter your descriptive prompt here..."
            )
            # Pre-filled with the same text as generate_image's default
            # negative_prompt.
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                lines=3,
                value="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",
                placeholder="Enter what you DON'T want to see in the image..."
            )
            with gr.Row():
                # Checked by default, so the slider value is ignored by
                # generate_image until the user unchecks it.
                randomize_seed_checkbox = gr.Checkbox(
                    label="Randomize Seed",
                    value=True
                )
                # Same 0..1000000 range that generate_image uses when it
                # randomizes; also an output, so it shows the seed used.
                seed_input = gr.Slider(
                    minimum=0,
                    maximum=1000000,
                    value=1,
                    step=1,
                    label="Seed"
                )
            # Hidden (visible=False) selector; its value is still passed to
            # generate_image as model_choice.
            model_radio = gr.Radio(
                choices=["14B", "2B"],
                value="14B",
                label="Model Selection",
                visible=False
            )
            generate_button = gr.Button("Generate Image")
        # Output panel.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")
    # Examples supply only the prompt, so generate_image runs with its own
    # defaults for every other argument.
    gr.Examples(
        examples=example_prompts,
        inputs=[prompt_input],
        outputs=[output_image, seed_input],
        fn=generate_image,
        cache_examples="lazy"
    )
    # Inputs are listed in generate_image's positional parameter order:
    # prompt, negative_prompt, seed, randomize_seed, model_choice.
    generate_button.click(
        fn=generate_image,
        inputs=[prompt_input, negative_prompt_input, seed_input, randomize_seed_checkbox, model_radio],
        outputs=[output_image, seed_input]
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()