Spaces:
Runtime error
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionPipeline, FluxPipeline | |
# Initialize models.
# NOTE: the original "runwayml/stable-diffusion-v1-5" repository was removed
# from the Hugging Face Hub; the community mirror below hosts the same weights.
SD_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
FLUX_MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Fall back to CPU when no GPU is present (e.g. CPU-only Space hardware):
# calling .to("cuda") or enabling model CPU offload without an accelerator
# raises at import time and takes the whole app down.
_HAS_CUDA = torch.cuda.is_available()
# float16 is only a sensible choice on the GPU; use float32 on CPU.
_SD_DTYPE = torch.float16 if _HAS_CUDA else torch.float32

sd_model = StableDiffusionPipeline.from_pretrained(SD_MODEL_ID, torch_dtype=_SD_DTYPE)
sd_model.to("cuda" if _HAS_CUDA else "cpu")

# FLUX weights are published in bfloat16; the diffusers docs note that running
# FLUX in float16 forces clipping of text-encoder activations, which changes the
# output. Only fall back to float16 on GPUs without bfloat16 support.
if _HAS_CUDA and not torch.cuda.is_bf16_supported():
    _FLUX_DTYPE = torch.float16
else:
    _FLUX_DTYPE = torch.bfloat16

flux_model = FluxPipeline.from_pretrained(FLUX_MODEL_ID, torch_dtype=_FLUX_DTYPE)
if _HAS_CUDA:
    # Move sub-models onto the GPU only while they run, to reduce peak VRAM
    # (the SD pipeline above already occupies GPU memory).
    flux_model.enable_model_cpu_offload()
def generate_sd_image(prompt):
    """Generate one image for *prompt* using the module-level SD pipeline.

    Returns the first entry of the pipeline output's ``images`` list
    (presumably a PIL image, the diffusers default output type).
    """
    pipeline_output = sd_model(prompt)
    first_image = pipeline_output.images[0]
    return first_image
def generate_flux_image(prompt):
    """Generate one image for *prompt* using the module-level FLUX.1-schnell pipeline.

    FLUX.1-schnell is timestep-distilled. Per the diffusers docs it needs
    ``guidance_scale=0.0``, works with very few denoising steps, and accepts a
    T5 prompt length of at most 256 tokens. The pipeline default is 512, so
    the limit is passed explicitly.

    Returns the first entry of the pipeline output's ``images`` list.
    """
    result = flux_model(
        prompt,
        guidance_scale=0.0,  # classifier-free guidance must be off for schnell
        num_inference_steps=4,  # few-step sampling for the distilled model
        max_sequence_length=256,  # schnell limit; pipeline default would be 512
    )
    return result.images[0]
def generate_image(prompt, model_choice):
    """Dispatch *prompt* to the generator selected by *model_choice*.

    Args:
        prompt: Text prompt forwarded unchanged to the chosen generator.
        model_choice: ``"Stable Diffusion"`` or ``"Flux"``. ``None`` (e.g. no
            radio option selected) keeps the previous behavior and uses Flux.

    Returns:
        Whatever the selected generator returns (its first generated image).

    Raises:
        ValueError: If *model_choice* is not a recognized model label. The
            original code silently routed such values to Flux.
    """
    choice = "Flux" if model_choice is None else model_choice
    if choice == "Stable Diffusion":
        return generate_sd_image(prompt)
    if choice == "Flux":
        return generate_flux_image(prompt)
    raise ValueError(
        f"Unknown model choice {model_choice!r}; "
        "expected 'Stable Diffusion' or 'Flux'"
    )
# Create Gradio interface.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Enter your prompt"),
        # Preselect a model so the callback never receives None when the user
        # skips the radio; "Flux" matches what an empty choice produced before.
        gr.Radio(["Stable Diffusion", "Flux"], value="Flux", label="Choose Model"),
    ],
    outputs=gr.Image(type="pil"),
    title="Image Generation with Stable Diffusion and Flux",
    description="Generate images using Stable Diffusion (Midjourney-like) or Flux models.",
)

# Only start the server when run as a script (as Spaces does with app.py), so
# importing this module does not block on a running web server.
if __name__ == "__main__":
    iface.launch()