File size: 2,733 Bytes
218f4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442bd7f
218f4d0
 
 
 
 
 
442bd7f
 
218f4d0
442bd7f
218f4d0
442bd7f
218f4d0
 
 
 
d28510c
 
 
 
 
218f4d0
 
78542a8
 
 
 
 
218f4d0
78542a8
218f4d0
 
 
 
 
 
442bd7f
 
218f4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import contextlib
import logging
import os

from dotenv import load_dotenv
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import gradio as gr  # Import Gradio
from PIL import Image

# Pull KEY=VALUE pairs from a .env file into os.environ so StableBuddyApp can
# read AUTH_TOKEN; variables already present in the real environment win
# (override=False is python-dotenv's default, spelled out here).
load_dotenv(override=False)

class StableBuddyApp:
    """Text-to-image helper around a Stable Diffusion pipeline.

    The pipeline is loaded once in ``__init__``; ``generate_image`` runs it
    and writes the result to a PNG file whose path is returned.
    """

    def __init__(self, model_id="CompVis/stable-diffusion-v1-4"):
        """Load the pipeline onto CUDA when available, otherwise the CPU.

        Args:
            model_id: Model identifier handed to
                ``StableDiffusionPipeline.from_pretrained``. Defaults to the
                model this app has always used.

        Raises:
            ValueError: If the ``AUTH_TOKEN`` environment variable is unset
                or empty.
        """
        # Stored on the instance because generate_image() branches on it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        auth_token = os.getenv("AUTH_TOKEN")
        if not auth_token:
            raise ValueError("AUTH_TOKEN environment variable is not set.")

        # Use float16 for GPU and float32 for CPU to manage VRAM.
        # NOTE(review): the 'fp16' revision is requested on CPU as well; the
        # weights are presumably converted to float32 via torch_dtype - confirm.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, revision='fp16', torch_dtype=torch_dtype, use_auth_token=auth_token
        )
        self.pipe.to(self.device)

    def generate_image(self, prompt, guidance_scale=8.5, output_dir='data'):
        """Generate an image for *prompt* and save it as a PNG.

        Args:
            prompt: Text prompt passed to the pipeline.
            guidance_scale: Value forwarded as the pipeline's
                ``guidance_scale`` (default 8.5, the previously fixed value).
            output_dir: Directory the image is written to; created,
                including parents, if it does not exist.

        Returns:
            The path ``<output_dir>/generated_image.png`` on success (each
            call overwrites the previous file), or ``None`` if generation or
            saving raised an exception.
        """
        try:
            # autocast only on GPU; the CPU path runs with no extra context.
            if self.device == "cuda":
                precision = autocast(self.device)
            else:
                precision = contextlib.nullcontext()
            with precision:
                image = self.pipe(prompt, guidance_scale=guidance_scale).images[0]

            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(output_dir, exist_ok=True)
            image_path = os.path.join(output_dir, 'generated_image.png')
            image.save(image_path)
        except Exception as e:
            # UI-callback boundary: keep the app alive but record the full
            # traceback instead of a bare one-line print to stdout.
            logging.getLogger(__name__).exception("An error occurred: %s", e)
            return None

        return image_path  # Gradio displays/serves the file at this path

# Single shared app instance, created at import time: the model is loaded
# once here (StableBuddyApp.__init__ raises ValueError if AUTH_TOKEN is unset)
# and reused by every generate_and_download() call.
stable_buddy_app = StableBuddyApp()

# Gradio callback: one generated image feeds both the preview and the download output.
def generate_and_download(prompt):
    """Generate an image for *prompt* and return its file path twice.

    The first element goes to the ``gr.Image`` preview, the second to the
    ``gr.File`` download component. Both are ``None`` when
    ``stable_buddy_app.generate_image`` fails and returns ``None``.
    """
    path = stable_buddy_app.generate_image(prompt)
    preview = download = path
    return preview, download

# Build the Gradio UI: a prompt text box in; an image preview plus a
# downloadable file out (both fed by generate_and_download).
prompt_input = gr.Textbox(label="Enter Prompt")
image_output = gr.Image(type="filepath", label="Generated Image")
file_output = gr.File(label="Download Image")

iface = gr.Interface(
    fn=generate_and_download,
    inputs=prompt_input,
    outputs=[image_output, file_output],
    title="Stable Buddy",
    description="Generate images using Stable Diffusion.",
)

# Start the Gradio server for the interface defined above.
iface.launch()