"""Stable Buddy: a Gradio web UI that generates images with Stable Diffusion.

Requires an ``AUTH_TOKEN`` environment variable (a ``.env`` file is loaded
first). Its value is passed to ``StableDiffusionPipeline.from_pretrained``
as ``use_auth_token``.
"""

import contextlib
import logging
import os
import uuid

from dotenv import load_dotenv
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import gradio as gr  # Import Gradio
from PIL import Image

logger = logging.getLogger(__name__)

# Load the environment variables from .env
load_dotenv()


class StableBuddyApp:
    """Wrap a Stable Diffusion pipeline and save generated images as PNG files."""

    def __init__(self):
        """Load the Stable Diffusion pipeline onto the GPU if available, else the CPU.

        Raises:
            ValueError: if the AUTH_TOKEN environment variable is unset or empty.
        """
        model_id = "CompVis/stable-diffusion-v1-4"
        # Stored on the instance so generate_image can decide whether to autocast.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        auth_token = os.getenv("AUTH_TOKEN")
        if not auth_token:
            raise ValueError("AUTH_TOKEN environment variable is not set.")

        # float16 on the GPU to save VRAM; float32 on the CPU.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch_dtype,
            use_auth_token=auth_token,
        )
        self.pipe.to(self.device)

    def generate_image(self, prompt, guidance_scale=8.5):
        """Generate an image from *prompt* and save it under ``data/``.

        Args:
            prompt: text prompt passed to the pipeline.
            guidance_scale: guidance strength passed to the pipeline
                (default 8.5).

        Returns:
            The path of the saved PNG, or None if generation or saving
            failed. The failure is logged with its traceback.
        """
        try:
            # Autocast is only applied on the GPU; on the CPU the call runs as-is.
            precision_ctx = (
                autocast(self.device)
                if self.device == "cuda"
                else contextlib.nullcontext()
            )
            with precision_ctx:
                image = self.pipe(prompt, guidance_scale=guidance_scale).images[0]

            output_dir = "data"
            os.makedirs(output_dir, exist_ok=True)

            # A unique name per request: a single fixed filename is overwritten
            # by every generation, so concurrent users could receive each
            # other's image. NOTE: files are never deleted here; add cleanup
            # if disk usage matters.
            image_path = os.path.join(
                output_dir, f"generated_image_{uuid.uuid4().hex}.png"
            )
            image.save(image_path)
            return image_path
        except Exception as exc:
            # Top-level boundary for the UI callback: log with traceback, signal failure.
            logger.exception("An error occurred: %s", exc)
            return None


# Create an instance of the StableBuddyApp (this loads the model at import time).
stable_buddy_app = StableBuddyApp()


def generate_and_download(prompt):
    """Gradio callback: generate an image and return its path for both outputs.

    Returns:
        A ``(path, path)`` tuple: the same file feeds the image preview and the
        download link.

    Raises:
        gr.Error: if generation failed, so Gradio reports the failure to the
            user instead of silently leaving both outputs empty.
    """
    image_path = stable_buddy_app.generate_image(prompt)
    if image_path is None:
        raise gr.Error("Image generation failed. Check the server logs for details.")
    return image_path, image_path


# Create Gradio Interface with one text input and two outputs (display + download).
iface = gr.Interface(
    fn=generate_and_download,
    inputs=gr.Textbox(label="Enter Prompt"),
    outputs=[
        gr.Image(type="filepath", label="Generated Image"),
        gr.File(label="Download Image"),
    ],
    title="Stable Buddy",
    description="Generate images using Stable Diffusion.",
)

# Launch the Gradio app
iface.launch()