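# Stable-Cascade / app.py
# Last commit 9b2326c (verified) by ehristoforu: "Update app.py" (file size 1.09 kB).
# Header captured from the Hugging Face file viewer, which also offers
# "raw", "history" and "blame" links for this file.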
import torch
import gradio as gr
# Pillow's module is `Image` (capital I). Imports are case-sensitive, and
# `from PIL import image` raises ImportError at startup.
from PIL import Image
import spaces
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Device both pipelines are moved to right after loading.
device = "cuda"

# Number of images requested from the prior for each prompt (read by `gen`).
num_images_per_prompt = 1

# Prior pipeline: maps the text prompt to image embeddings. Loaded in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)

# Decoder pipeline: turns the prior's image embeddings into images. Loaded in
# float16, which is why `gen` casts the prior's embeddings with `.half()`.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)

# Sample prompt values.
# NOTE(review): not referenced in the visible code. `gen` takes its prompt and
# negative prompt as arguments. Confirm nothing else uses these before removing.
prompt = "Anthropomorphic cat dressed as a pilot"
negative_prompt = ""
@spaces.GPU
def gen(
    prompt,
    negative,
    width,
    height,
    prior_guidance_scale=4.0,
    prior_steps=20,
    decoder_guidance_scale=0.0,
    decoder_steps=10,
):
    """Generate images for *prompt* with the two-stage prior + decoder pipeline.

    The prior turns the text into image embeddings. The decoder turns those
    embeddings into PIL images. The four trailing keyword arguments default
    to the values this function previously hard-coded, so existing callers
    (e.g. a UI passing four positional inputs) behave as before.

    Args:
        prompt: Text describing the desired image; given to both stages.
        negative: Negative prompt; given to both stages.
        width: Output width in pixels; coerced with ``int()``.
        height: Output height in pixels; coerced with ``int()``.
        prior_guidance_scale: ``guidance_scale`` passed to the prior.
        prior_steps: ``num_inference_steps`` passed to the prior.
        decoder_guidance_scale: ``guidance_scale`` passed to the decoder.
        decoder_steps: ``num_inference_steps`` passed to the decoder.

    Returns:
        The decoder's ``.images`` (requested with ``output_type="pil"``).
    """
    # Callers such as UI number or slider inputs may hand over floats (or
    # numeric strings). Height and width are meant to be whole pixel sizes.
    width = int(width)
    height = int(height)

    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_steps,
    )

    # The prior was loaded in bfloat16 but the decoder in float16, so the
    # embeddings are cast to half precision before being handed over.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        negative_prompt=negative,
        guidance_scale=decoder_guidance_scale,
        output_type="pil",
        num_inference_steps=decoder_steps,
    ).images
    return decoder_output