# Hugging Face Space status: Sleeping
import logging
import os

import gradio as gr  # Import Gradio
import torch
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from PIL import Image
from torch import autocast
# Load variables from a .env file (if one is found) into the process
# environment, so AUTH_TOKEN can be supplied locally without exporting it.
load_dotenv()
class StableBuddyApp:
    """Wrap a Stable Diffusion pipeline for text-to-image generation."""

    def __init__(self):
        """Load the Stable Diffusion v1.4 pipeline onto the best available device.

        The Hugging Face access token is read from the AUTH_TOKEN environment
        variable.

        Raises:
            ValueError: If AUTH_TOKEN is unset or empty.
        """
        model_id = "CompVis/stable-diffusion-v1-4"
        # Kept on the instance so generate_image() can pick a matching
        # autocast context for the device the model actually lives on.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        auth_token = os.getenv("AUTH_TOKEN")
        if not auth_token:
            raise ValueError("AUTH_TOKEN environment variable is not set.")
        # float16 lets Stable Diffusion fit in ~4GB of VRAM on a GPU. Half
        # precision is generally unsupported on CPU, so the CPU fallback loads
        # the (fp16) weights upcast to float32 instead of failing at inference.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, revision="fp16", torch_dtype=dtype, use_auth_token=auth_token
        )
        self.pipe.to(self.device)

    def generate_image(
        self, prompt, guidance_scale=8.5, output_path="data/generated_image.png"
    ):
        """Generate an image for *prompt* and save it to *output_path*.

        Args:
            prompt: Text description of the image to generate.
            guidance_scale: Guidance scale forwarded to the pipeline call
                (default 8.5).
            output_path: Where the image is written. Missing parent directories
                are created. With the default, every call overwrites the same file.

        Returns:
            The path the image was saved to, or None if the prompt is blank or
            generation/saving raised (the error is logged with its traceback).
        """
        # Skip a full diffusion run when the user submitted nothing.
        if not prompt or not prompt.strip():
            return None
        try:
            # Mixed precision only makes sense on CUDA; on CPU the context is
            # created disabled so it is a no-op.
            with autocast(self.device, enabled=self.device == "cuda"):
                image = self.pipe(prompt, guidance_scale=guidance_scale).images[0]
            # Without this, a missing output directory made every save fail.
            directory = os.path.dirname(output_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            image.save(output_path)
            return output_path
        except Exception as exc:
            # UI boundary: log the full traceback and return None so the
            # caller can show an empty result instead of crashing.
            logging.getLogger(__name__).exception("An error occurred: %s", exc)
            return None
# Build the app once at import time. This loads the Stable Diffusion pipeline,
# and StableBuddyApp() raises ValueError if AUTH_TOKEN is not set.
stable_buddy_app = StableBuddyApp()
def generate_and_download(prompt):
    """Generate an image for *prompt* and return its path twice.

    The interface has two outputs, an image preview and a download link, and
    both are fed the same file path. If generate_image() returns None, both
    outputs are None.
    """
    saved_path = stable_buddy_app.generate_image(prompt)
    # One value per output component: (preview, download).
    return (saved_path,) * 2
# Gradio UI: one prompt text box in, and two outputs produced by
# generate_and_download, an image preview and a downloadable copy of the
# same generated file.
iface = gr.Interface(
    fn=generate_and_download,
    inputs=gr.Textbox(label="Enter Prompt"),
    outputs=[
        gr.Image(type="filepath", label="Generated Image"),
        gr.File(label="Download Image"),
    ],
    title="Stable Buddy",
    description="Generate images using Stable Diffusion.",
)

# Launch only when executed as a script, so importing this module does not
# start a blocking web server.
if __name__ == "__main__":
    iface.launch()