File size: 1,944 Bytes
44769ce
738528e
 
 
 
 
 
9497b74
738528e
 
 
9497b74
738528e
 
 
 
 
9497b74
738528e
9497b74
738528e
 
44769ce
738528e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import io

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# Page header rendered at the top of the main area.
st.title("🖼️ Stable Diffusion Image Generator")
st.write("Generate images from text using the Stable Diffusion v2-1 model!")

# Sidebar controls: a free-text prompt (pre-filled with an example) and the
# button whose click triggers generation further down the script.
sidebar = st.sidebar
sidebar.title("Input Options")
default_prompt = "A futuristic cityscape at sunset, vibrant colors, 8k"
prompt = sidebar.text_input("Enter your prompt", default_prompt)
generate_button = sidebar.button("Generate Image")

# Load the pipeline once; st.cache_resource caches the returned object so it
# is not reloaded on every script rerun.
@st.cache_resource
def load_pipeline():
    """Load the Stable Diffusion 2.1 pipeline and move it to a device.

    When CUDA is available, the fp16 weight variant is loaded in float16 and
    the pipeline is moved to the GPU. Otherwise the default weights are
    loaded in float32 and the pipeline runs on the CPU.

    Returns:
        StableDiffusionPipeline: the loaded pipeline, already on its device.
    """
    use_cuda = torch.cuda.is_available()
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # `variant` selects the *.fp16 weight files; the older
        # revision="fp16" branch form is deprecated in diffusers.
        variant="fp16" if use_cuda else None,
    )
    return pipe.to("cuda" if use_cuda else "cpu")


pipe = load_pipeline()

# Generate an image when the button is clicked.
if generate_button:
    if not prompt.strip():
        # Skip the expensive pipeline call when there is nothing to condition on.
        st.warning("Please enter a prompt.")
    else:
        st.write(f"### Prompt: {prompt}")
        with st.spinner("Generating image... Please wait."):
            try:
                # The pipeline already runs in the dtype chosen by
                # load_pipeline, so no torch.autocast wrapper is used: diffusers
                # advises against it, and on CPU it would silently downcast the
                # float32 model to bfloat16.
                image = pipe(prompt).images[0]

                # Display the generated image
                st.image(image, caption="Generated Image", use_column_width=True)

                # Encode the PNG in memory rather than writing a fixed file on
                # the server's disk, which concurrent sessions would overwrite.
                png_buffer = io.BytesIO()
                image.save(png_buffer, format="PNG")
                st.download_button(
                    label="Download Image",
                    data=png_buffer.getvalue(),
                    file_name="generated_image.png",
                    mime="image/png",
                )
            except Exception as e:
                # Top-level UI boundary: surface any failure to the user.
                st.error(f"An error occurred: {e}")