# StableBuddy / app.py
# CHEONMA010 - "Update app.py" (commit 78542a8, verified)
import os
import uuid
from contextlib import nullcontext

import gradio as gr  # Import Gradio
import torch
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from PIL import Image
from torch import autocast
# Load the environment variables from .env before anything reads them;
# StableBuddyApp.__init__ looks up AUTH_TOKEN via os.getenv.
load_dotenv()
class StableBuddyApp:
    """Text-to-image generator backed by a Stable Diffusion pipeline.

    The pipeline is loaded once in ``__init__`` and reused by every call to
    ``generate_image``.
    """

    def __init__(self):
        """Load the Stable Diffusion pipeline onto CUDA when available, else CPU.

        Raises:
            ValueError: If the ``AUTH_TOKEN`` environment variable is unset
                or empty.
        """
        model_id = "CompVis/stable-diffusion-v1-4"
        # Stored on the instance because generate_image() branches on it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Get the auth token from the environment variable.
        auth_token = os.getenv("AUTH_TOKEN")
        if not auth_token:
            raise ValueError("AUTH_TOKEN environment variable is not set.")

        # Use float16 on GPU to manage VRAM and float32 on CPU.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        # NOTE(review): the "fp16" revision is requested even when running on
        # CPU with float32 -- confirm this is intended for CPU-only hosts.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch_dtype,
            use_auth_token=auth_token,
        )
        self.pipe.to(self.device)

    def generate_image(self, prompt):
        """Generate an image for ``prompt`` and save it as a PNG under ``data/``.

        Every call writes to its own uniquely named file, so overlapping or
        successive requests never overwrite an image that another caller is
        still displaying or downloading.

        Args:
            prompt: Text description of the image to generate.

        Returns:
            The path of the saved PNG, or ``None`` when the prompt is blank or
            when generation or saving fails (the error is printed).
        """
        # A missing or whitespace-only prompt cannot describe an image; skip
        # the expensive pipeline run and report "no image" to the caller.
        if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
            return None

        try:
            # Mixed-precision autocast only applies on GPU; on CPU the
            # pipeline runs inside a no-op context.
            if self.device == "cuda":
                precision_ctx = autocast(self.device)
            else:
                precision_ctx = nullcontext()
            with precision_ctx:
                image = self.pipe(prompt, guidance_scale=8.5).images[0]

            # exist_ok avoids the check-then-create race between requests.
            output_dir = "data"
            os.makedirs(output_dir, exist_ok=True)

            # A random file name keeps each request's result separate.
            image_path = os.path.join(
                output_dir, f"generated_{uuid.uuid4().hex}.png"
            )
            image.save(image_path)
            return image_path  # Gradio displays/serves the file at this path
        except Exception as e:
            # Best-effort boundary for the UI: report the failure and return
            # None instead of propagating into the Gradio callback.
            print(f"An error occurred: {e}")
            return None
# Create the single StableBuddyApp instance used by the Gradio callback below.
# This runs at import time: the constructor loads the pipeline and raises
# ValueError if AUTH_TOKEN is not set.
stable_buddy_app = StableBuddyApp()
# Gradio callback: one generated file feeds both the preview and the download
def generate_and_download(prompt):
    """Generate an image for ``prompt`` and return its path twice.

    The first element is meant for the image display output and the second
    for the file download output; both are whatever
    ``stable_buddy_app.generate_image`` returned (possibly ``None``).
    """
    return (stable_buddy_app.generate_image(prompt),) * 2
# Build the Gradio UI: a single prompt box in, an image preview plus a
# downloadable file out.
prompt_input = gr.Textbox(label="Enter Prompt")
result_outputs = [
    gr.Image(type="filepath", label="Generated Image"),  # on-page preview
    gr.File(label="Download Image"),  # download link for the same file
]
iface = gr.Interface(
    fn=generate_and_download,
    inputs=prompt_input,
    outputs=result_outputs,
    title="Stable Buddy",
    description="Generate images using Stable Diffusion.",
)

# Start serving the app
iface.launch()