Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import torch | |
import PIL.Image | |
from diffusers.utils import numpy_to_pil | |
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline | |
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS | |
from previewer.modules import Previewer | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prior pipeline: turns the text prompt into image embeddings (consumed by
# `genie`). Loaded in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to(device)

# Decoder pipeline: turns the prior's image embeddings into the final image.
# Loaded in float16, so embeddings from the bfloat16 prior must be cast
# before they are handed over.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to(device)
def genie(Prompt, Negative_prompt, seed) -> PIL.Image.Image:
    """Generate one 1024x1024 image from a text prompt with Stable Cascade.

    Runs the module-level ``prior`` pipeline to obtain image embeddings, then
    feeds them to the module-level ``decoder`` pipeline to render the image.

    Args:
        Prompt: Text describing the desired image.
        Negative_prompt: Text describing what to avoid; passed to both stages.
        seed: Seed from the UI slider. ``0`` means "random": no generator is
            created, so each call draws fresh noise. Any other value seeds a
            dedicated generator, making the result reproducible for the same
            inputs.

    Returns:
        The first image produced by the decoder.
    """
    # Gradio sliders may deliver the value as a float; torch generators need an int.
    seed = int(seed)
    # A per-request generator keeps seeding local to this call instead of
    # mutating torch/numpy global RNG state. It only takes effect when it is
    # passed to the pipelines explicitly, as done below.
    generator = None if seed == 0 else torch.Generator(device=device).manual_seed(seed)

    prior_output = prior(
        prompt=Prompt,
        height=1024,
        width=1024,
        negative_prompt=Negative_prompt,
        guidance_scale=4.0,
        num_inference_steps=20,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        generator=generator,
    )
    # The prior was loaded in bfloat16 and the decoder in float16, so the
    # embeddings are cast to half precision before the hand-off.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=Prompt,
        negative_prompt=Negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
        generator=generator,
    ).images
    return decoder_output[0]
# Seed input for `genie`; a value of 0 is treated there as "pick a random seed".
seed_slider = gr.Slider(
    minimum=0,
    step=1,
    maximum=9999999999999999,
    randomize=True,
    label='Seed: 0 is Random',
)

# Two text boxes (prompt, negative prompt) plus the seed slider in; one image out.
demo = gr.Interface(
    fn=genie,
    inputs=['text', 'text', seed_slider],
    outputs='image',
)
demo.launch()