import torch
import spaces
from diffusers import StableDiffusionPipeline
import gradio as gr

repo = "IDKiro/sdxs-512-0.9"
seed = 42

# Decide the device once and reuse the answer everywhere below.
# float16 is only useful on CUDA: diffusers warns that float16 pipelines cannot
# run on CPU, so fall back to float32 when no GPU is available.
use_cuda = torch.cuda.is_available()
weight_type = torch.float16 if use_cuda else torch.float32

# Debug probe for device placement at import time (under ZeroGPU this was
# observed to print 'cpu' even after .cuda()). Only call .cuda() when CUDA is
# available, so CPU-only hosts do not crash here before reaching the guarded
# model move below.
zero = torch.Tensor([0])
if use_cuda:
    zero = zero.cuda()
print(zero.device)

# Load model.
pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=weight_type)

# NOTE: `generator` is the diffusion *pipeline* called by generate(), not a
# torch.Generator; the name is kept because generate() refers to it.
generator = pipe

# Move to GPU if available.
if use_cuda:
    generator = generator.to("cuda")

@spaces.GPU(duration=120)
def generate(prompts, *, num_inference_steps=1, guidance_scale=0.0):
    """Generate one image per prompt with the one-step SDXS pipeline.

    The Gradio interface runs with ``batch=True``, so ``prompts`` holds the
    textbox values of every request grouped into the current batch.

    Args:
        prompts: Iterable of prompt strings, one per batched request.
        num_inference_steps: Denoising steps. SDXS is a one-step model; the
            pipeline's own default (50) is wrong for it, so this defaults to 1.
        guidance_scale: Classifier-free guidance scale. 0 disables CFG, as in
            the SDXS model card usage example. Both tuning parameters are
            keyword-only so Gradio does not mistake them for input components.

    Returns:
        A one-element list containing the list of generated images, i.e.
        Gradio's batch format of one list per output component.
    """
    prompts = list(prompts)
    # One identically-seeded CPU generator per prompt: each request's image is
    # reproducible and independent of which other requests share its batch.
    # CPU generators are accepted by diffusers even when the model is on CUDA.
    rngs = [torch.Generator("cpu").manual_seed(seed) for _ in prompts]
    images = generator(
        prompts,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=rngs,
    ).images
    return [images]


# Web UI: a prompt textbox in, an image out. With batch=True the queue groups
# up to max_batch_size pending requests into a single generate() call.
demo = gr.Interface(
    fn=generate,
    inputs="textbox",
    outputs="image",
    title="SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions",
    description="This demo showcases [SDXS](https://arxiv.org/abs/2403.16627)",
    batch=True,
    max_batch_size=4,  # tune to the memory of the CPU/GPU serving the demo
)
# queue() configures the app in place (and returns it); batching goes through it.
demo.queue()

if __name__ == "__main__":
    demo.launch()