"""Gradio demo that turns a text prompt into an image with Stable Cascade."""

import functools

import gradio as gr
import numpy as np
import torch
import PIL.Image
from diffusers.utils import numpy_to_pil
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

from previewer.modules import Previewer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Both pipelines are created once at start-up and shared by every request.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)


@functools.lru_cache(maxsize=1)
def _load_previewer() -> Previewer:
    """Build the latent previewer on first use and reuse it afterwards.

    The checkpoint is read onto the CPU first, then the module is switched to
    eval mode with gradients disabled and moved to ``device`` in bfloat16.
    Caching avoids re-reading the checkpoint on every request.
    """
    previewer = Previewer()
    state_dict = torch.load(
        "previewer/previewer_v1_100k.pt", map_location=torch.device('cpu')
    )["state_dict"]
    previewer.load_state_dict(state_dict)
    previewer.eval().requires_grad_(False).to(device).to(torch.bfloat16)
    return previewer


def _preview_latents(i, t, latents):
    """Run ``latents`` through the previewer and return the result as PIL images.

    ``i`` and ``t`` are accepted but unused.

    NOTE(review): nothing passes this function to the prior pipeline, so no
    previews are produced today; it is kept so the preview path can be wired
    up without rewriting it.
    """
    output = _load_previewer()(latents)
    return numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())


def _make_generator(seed: float) -> torch.Generator:
    """Return a request-local RNG on ``device`` for the given UI seed.

    A seed of 0 means "random": the generator is seeded from a
    non-deterministic source. Any other value seeds it deterministically.
    """
    generator = torch.Generator(device=device)
    # The slider value is coerced because Generator.manual_seed requires an int.
    seed = int(seed)
    if seed == 0:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return generator


def genie(Prompt: str, Negative_prompt: str, seed: float) -> PIL.Image.Image:
    """Generate one image for ``Prompt`` with the prior and decoder pipelines.

    Args:
        Prompt: Text describing the desired image.
        Negative_prompt: Text describing what the image should avoid.
        seed: Seed from the UI slider; 0 requests a random seed.

    Returns:
        The first image produced by the decoder pipeline.
    """
    # A per-request generator keeps the seed from leaking into (or being
    # disturbed by) other requests through torch's global RNG state.
    generator = _make_generator(seed)

    # NOTE(review): callback_steps is forwarded unchanged from the original
    # call even though no callback is supplied alongside it.
    callback_steps = 1
    prior_output = prior(
        prompt=Prompt,
        height=1024,
        width=1024,
        negative_prompt=Negative_prompt,
        guidance_scale=4.0,
        num_inference_steps=20,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        generator=generator,
        callback_steps=callback_steps,
    )

    # The prior runs in bfloat16 while the decoder was loaded in float16, so
    # the embeddings are converted to half precision before decoding.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=Prompt,
        negative_prompt=Negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
        generator=generator,
    ).images
    return decoder_output[0]


gr.Interface(
    fn=genie,
    inputs=[
        'text',
        'text',
        gr.Slider(
            minimum=0,
            step=1,
            maximum=9999999999999999,
            randomize=True,
            label='Seed: 0 is Random',
        ),
    ],
    outputs='image',
).launch()