Vilen03's picture
Update app.py
f3da5a9 verified
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
# Initialize models.
# NOTE: the original "runwayml/stable-diffusion-v1-5" Hub repository was taken
# down; the same v1.5 weights are published under
# "stable-diffusion-v1-5/stable-diffusion-v1-5", so load from there.
sd_model = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Keep the whole SD pipeline resident on the GPU; this line requires a CUDA device.
sd_model.to("cuda")
# NOTE(review): the FLUX.1-schnell model card loads the pipeline in
# torch.bfloat16; fp16 is kept here, but it may change outputs relative to
# bf16/fp32 — confirm on the target GPU.
flux_model = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
)
# Instead of moving the whole pipeline to the GPU, let diffusers move each
# sub-model onto the GPU only while it is running.
flux_model.enable_model_cpu_offload()
def generate_sd_image(prompt):
    """Run the Stable Diffusion pipeline on *prompt* and return one image.

    Args:
        prompt: Text prompt entered in the UI.

    Returns:
        Element 0 of the pipeline output's ``images`` list.
    """
    pipeline_output = sd_model(prompt)
    images = pipeline_output.images
    return images[0]
def generate_flux_image(prompt):
    """Generate one image for *prompt* with the FLUX.1-schnell pipeline.

    Uses the sampling settings documented for the timestep-distilled
    schnell checkpoint:
    - guidance disabled (``guidance_scale=0.0``);
    - four denoising steps;
    - a text-encoder sequence length of at most 256 tokens.

    Args:
        prompt: Text prompt entered in the UI.

    Returns:
        Element 0 of the pipeline output's ``images`` list.
    """
    result = flux_model(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        # schnell supports at most 256 T5 tokens; FluxPipeline defaults to 512.
        max_sequence_length=256,
    )
    return result.images[0]
def generate_image(prompt, model_choice):
    """Dispatch *prompt* to the generator picked in the model selector.

    Args:
        prompt: Text prompt entered in the UI.
        model_choice: Selected model label. Exactly "Stable Diffusion"
            selects the SD pipeline. Any other value, including "Flux" or
            no selection at all, falls through to the Flux pipeline.

    Returns:
        Whatever the chosen generator returns.
    """
    wants_sd = model_choice == "Stable Diffusion"
    generator = generate_sd_image if wants_sd else generate_flux_image
    return generator(prompt)
# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Enter your prompt"),
        # Pre-select "Flux". With no selection the radio submits None, and
        # generate_image sends anything other than "Stable Diffusion" to Flux.
        # Showing that default makes the choice visible instead of silently
        # running a model the user never picked.
        gr.Radio(["Stable Diffusion", "Flux"], value="Flux", label="Choose Model"),
    ],
    outputs=gr.Image(type="pil"),
    title="Image Generation with Stable Diffusion and Flux",
    description="Generate images using Stable Diffusion (Midjourney-like) or Flux models.",
)
# Start the web server (runs at import time, as in the original script).
iface.launch()