Update app.py
Browse files
app.py
CHANGED
@@ -96,15 +96,6 @@ with st.sidebar:
|
|
96 |
# st.warning("Please enter a prompt.")
|
97 |
|
98 |
# COMPREHENSIVE CODE
|
99 |
-
import streamlit as st
|
100 |
-
import requests
|
101 |
-
from PIL import Image
|
102 |
-
import io
|
103 |
-
|
104 |
-
#------------------------------------------------------------------------
|
105 |
-
# Define functions
|
106 |
-
#------------------------------------------------------------------------
|
107 |
-
|
108 |
def query(payload):
|
109 |
response = requests.post(API_URL, headers=headers, json=payload)
|
110 |
if response.status_code != 200:
|
@@ -128,34 +119,39 @@ def main():
|
|
128 |
st.session_state["prompt"] = ""
|
129 |
|
130 |
# Input field for the prompt
|
131 |
-
prompt = st.text_input("Enter a prompt for image generation:", value=
|
132 |
|
133 |
if st.button("Generate Image"):
|
134 |
if prompt:
|
135 |
image = generate_image(prompt)
|
136 |
if image:
|
137 |
st.session_state["image"] = image # Store generated image in session state
|
138 |
-
st.session_state["prompt"] = prompt # Store the prompt in session state
|
139 |
st.image(image, caption="Generated Image")
|
|
|
140 |
else:
|
141 |
st.warning("Please enter a prompt.")
|
142 |
|
143 |
-
# Show download
|
144 |
if st.session_state["image"]:
|
145 |
image_bytes = io.BytesIO()
|
146 |
st.session_state["image"].save(image_bytes, format='PNG')
|
|
|
|
|
147 |
st.download_button(
|
148 |
label="Download Image",
|
149 |
data=image_bytes.getvalue(),
|
150 |
file_name="generated_image.png",
|
151 |
mime="image/png"
|
152 |
)
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
159 |
|
160 |
#------------------------------------------------------------------------
|
161 |
# Main Guard
|
|
|
96 |
# st.warning("Please enter a prompt.")
|
97 |
|
98 |
# COMPREHENSIVE CODE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def query(payload):
|
100 |
response = requests.post(API_URL, headers=headers, json=payload)
|
101 |
if response.status_code != 200:
|
|
|
119 |
st.session_state["prompt"] = ""
|
120 |
|
121 |
# Input field for the prompt
|
122 |
+
prompt = st.text_input("Enter a prompt for image generation:", value="")
|
123 |
|
124 |
if st.button("Generate Image"):
|
125 |
if prompt:
|
126 |
image = generate_image(prompt)
|
127 |
if image:
|
128 |
st.session_state["image"] = image # Store generated image in session state
|
|
|
129 |
st.image(image, caption="Generated Image")
|
130 |
+
st.session_state["prompt"] = prompt # Store the prompt in session state
|
131 |
else:
|
132 |
st.warning("Please enter a prompt.")
|
133 |
|
134 |
+
# Show download and reset buttons only if an image is generated
|
135 |
if st.session_state["image"]:
|
136 |
image_bytes = io.BytesIO()
|
137 |
st.session_state["image"].save(image_bytes, format='PNG')
|
138 |
+
|
139 |
+
# Download button
|
140 |
st.download_button(
|
141 |
label="Download Image",
|
142 |
data=image_bytes.getvalue(),
|
143 |
file_name="generated_image.png",
|
144 |
mime="image/png"
|
145 |
)
|
146 |
+
|
147 |
+
# Reset button
|
148 |
+
if st.button("Reset"):
|
149 |
+
# Clear session state variables
|
150 |
+
st.session_state["image"] = None
|
151 |
+
st.session_state["prompt"] = ""
|
152 |
+
|
153 |
+
# Clear UI by updating query params (this will force rerun of the app)
|
154 |
+
st.experimental_set_query_params()
|
155 |
|
156 |
#------------------------------------------------------------------------
|
157 |
# Main Guard
|