# app.py: Stable Cascade text-to-image Gradio demo.
# Author: Manjushri · commit aed27f4 ("Update app.py", verified) · 2.32 kB
import gradio as gr
import numpy as np
import torch
import PIL.Image
from diffusers.utils import numpy_to_pil
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from previewer.modules import Previewer
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Prior pipeline: turns a text prompt into image embeddings
# (genie() reads them from `prior_output.image_embeddings`). Kept in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to(device)

# Decoder pipeline: renders those embeddings into the final image.
# Kept in float16, which is why genie() casts the embeddings with .half().
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to(device)
def _make_generator(seed, device):
    """Return a ``torch.Generator`` on *device* seeded from the UI seed value.

    A seed of 0 means "random": the generator gets a fresh nondeterministic
    seed via ``Generator.seed()``. Any other value seeds it deterministically.
    The value is converted with ``int()`` first because ``gr.Slider`` may
    deliver it as a float.
    """
    generator = torch.Generator(device=device)
    if seed == 0:
        generator.seed()
    else:
        generator.manual_seed(int(seed))
    return generator


def genie(Prompt, Negative_prompt, seed) -> PIL.Image.Image:
    """Generate one 1024x1024 image for *Prompt* with the Stable Cascade pipelines.

    Args:
        Prompt: Text describing the desired image.
        Negative_prompt: Text describing what the image should avoid.
        seed: Seed from the UI slider. 0 picks a random seed; any other value
            makes the result reproducible.

    Returns:
        The first image produced by the decoder pipeline (``output_type="pil"``).
    """
    # One generator shared by both stages: the prior draws from it first and
    # the decoder continues the same stream, so a fixed seed reproduces the
    # whole run.
    generator = _make_generator(seed, device)

    # Stage 1: the prior turns the prompt into image embeddings.
    prior_output = prior(
        prompt=Prompt,
        height=1024,
        width=1024,
        negative_prompt=Negative_prompt,
        guidance_scale=4.0,
        num_inference_steps=20,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        generator=generator,
    )

    # Stage 2: the decoder was loaded in float16, so cast the prior's
    # (bfloat16) embeddings with .half() before handing them over.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=Prompt,
        negative_prompt=Negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
        generator=generator,
    ).images
    return decoder_output[0]
# Seed input: 0 means "pick a random seed"; the slider starts at a random value.
seed_slider = gr.Slider(
    minimum=0,
    step=1,
    maximum=9999999999999999,
    randomize=True,
    label='Seed: 0 is Random',
)

# Inputs map positionally onto genie(Prompt, Negative_prompt, seed).
demo = gr.Interface(
    fn=genie,
    inputs=['text', 'text', seed_slider],
    outputs='image',
)
demo.launch()