Spaces:
Runtime error
Runtime error
File size: 1,169 Bytes
408b580 f3da5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
# Initialize models.
# Pick the execution device at import time instead of assuming CUDA: the
# original hard-coded "cuda", which raises on a host without a GPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is kept on GPU (unchanged behaviour); on CPU fall back to
# float32 because many CPU kernels do not support float16.
_dtype = torch.float16 if _device == "cuda" else torch.float32

sd_model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_dtype,
)
sd_model.to(_device)

flux_model = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=_dtype,
)
if _device == "cuda":
    # Offloading moves sub-models between CPU and the accelerator, so it is
    # only meaningful when an accelerator is actually present.
    flux_model.enable_model_cpu_offload()
# NOTE(review): running FLUX in float32 on CPU presumably needs a large amount
# of RAM — confirm the host is sized for it before relying on the CPU path.
def generate_sd_image(prompt):
    """Generate an image for ``prompt`` with the module-level ``sd_model``.

    Args:
        prompt: Value passed as the single positional argument to the
            pipeline call.

    Returns:
        The first element of the ``images`` attribute of the pipeline result.
    """
    pipeline_output = sd_model(prompt)
    return pipeline_output.images[0]
def generate_flux_image(prompt):
    """Generate an image for ``prompt`` with the module-level ``flux_model``.

    The pipeline is called with ``guidance_scale=0.0`` and
    ``num_inference_steps=4``.

    Args:
        prompt: Value passed as the single positional argument to the
            pipeline call.

    Returns:
        The first element of the ``images`` attribute of the pipeline result.
    """
    sampling_options = {"guidance_scale": 0.0, "num_inference_steps": 4}
    pipeline_output = flux_model(prompt, **sampling_options)
    return pipeline_output.images[0]
def generate_image(prompt, model_choice):
    """Dispatch ``prompt`` to the image generator selected by ``model_choice``.

    Args:
        prompt: Text forwarded unchanged to the chosen generator.
        model_choice: ``"Stable Diffusion"`` selects ``generate_sd_image``;
            every other value selects ``generate_flux_image``.

    Returns:
        Whatever the selected generator returns.
    """
    use_stable_diffusion = model_choice == "Stable Diffusion"
    backend = generate_sd_image if use_stable_diffusion else generate_flux_image
    return backend(prompt)
# Build the Gradio UI: a prompt textbox and a model selector feed
# generate_image, whose result is shown in an image component.
# Components are created in the same order as before (inputs, then output).
_prompt_input = gr.Textbox(label="Enter your prompt")
_model_input = gr.Radio(["Stable Diffusion", "Flux"], label="Choose Model")
_image_output = gr.Image(type="pil")

iface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _model_input],
    outputs=_image_output,
    title="Image Generation with Stable Diffusion and Flux",
    description="Generate images using Stable Diffusion (Midjourney-like) or Flux models.",
)

# Start the web server for the interface.
iface.launch()
|