Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ from PIL import Image
|
|
9 |
import tempfile
|
10 |
import mimetypes
|
11 |
import uuid
|
|
|
12 |
|
13 |
# Load Firebase credentials from Hugging Face Secrets
|
14 |
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
|
@@ -35,15 +36,12 @@ if "current_user" not in st.session_state:
|
|
35 |
st.session_state.current_user = None
|
36 |
if "display_name" not in st.session_state:
|
37 |
st.session_state.display_name = None
|
38 |
-
if "
|
39 |
-
|
40 |
-
if "
|
41 |
-
st.session_state.
|
42 |
-
|
43 |
-
|
44 |
-
if "load_more_pressed" not in st.session_state:
|
45 |
-
st.session_state.load_more_pressed = False # Initialize load_more_pressed
|
46 |
-
|
47 |
TOKEN = os.getenv("TOKEN0")
|
48 |
API_URL = os.getenv("API_URL")
|
49 |
token_id = 0
|
@@ -126,8 +124,6 @@ def login_callback():
|
|
126 |
st.session_state.logged_in = True
|
127 |
st.session_state.current_user = user.uid
|
128 |
st.session_state.display_name = user.display_name # Store the display name
|
129 |
-
st.session_state.images_data = [] # Reset images data on login
|
130 |
-
st.session_state.start_index = 0
|
131 |
st.success("Logged in successfully!")
|
132 |
except Exception as e:
|
133 |
st.error(f"Login failed: {e}")
|
@@ -137,8 +133,6 @@ def logout_callback():
|
|
137 |
st.session_state.logged_in = False
|
138 |
st.session_state.current_user = None
|
139 |
st.session_state.display_name = None
|
140 |
-
st.session_state.images_data = [] # Clear images data on logout
|
141 |
-
st.session_state.start_index = 0
|
142 |
st.info("Logged out successfully!")
|
143 |
|
144 |
# Function to get image from url
|
@@ -151,8 +145,6 @@ def get_image_from_url(url):
|
|
151 |
response.raise_for_status()
|
152 |
image = Image.open(BytesIO(response.content))
|
153 |
return image, url # Return the image and the URL
|
154 |
-
|
155 |
-
|
156 |
except requests.exceptions.RequestException as e:
|
157 |
return f"Error fetching image: {e}", None
|
158 |
except Exception as e:
|
@@ -217,7 +209,7 @@ def download_image(image_url):
|
|
217 |
return None
|
218 |
|
219 |
# Function to store image and related data in Firebase
|
220 |
-
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url):
|
221 |
try:
|
222 |
ref = db.reference(f'users/{user_id}/images')
|
223 |
new_image_ref = ref.push()
|
@@ -226,6 +218,7 @@ def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url):
|
|
226 |
'aspect_ratio': aspect_ratio,
|
227 |
'realism': realism,
|
228 |
'image_url': image_url,
|
|
|
229 |
'timestamp': {'.sv': 'timestamp'}
|
230 |
})
|
231 |
st.success("Image data saved successfully!")
|
@@ -233,11 +226,14 @@ def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url):
|
|
233 |
st.error(f"Failed to save image data: {e}")
|
234 |
|
235 |
#Function to upload image to cloud storage
|
236 |
-
def upload_image_to_storage(image, user_id):
|
237 |
try:
|
238 |
bucket = storage.bucket()
|
239 |
image_id = str(uuid.uuid4())
|
240 |
-
|
|
|
|
|
|
|
241 |
blob = bucket.blob(file_path)
|
242 |
|
243 |
# Convert PIL Image to BytesIO object
|
@@ -271,6 +267,19 @@ def load_image_data(user_id, start_index, batch_size):
|
|
271 |
st.error(f"Failed to fetch image data from database: {e}")
|
272 |
return []
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
# Registration form
|
275 |
def registration_form():
|
276 |
with st.form("Registration"):
|
@@ -288,7 +297,6 @@ def login_form():
|
|
288 |
password = st.text_input("Password", type="password", key="login_password")
|
289 |
submit_button = st.form_submit_button("Login", on_click=login_callback)
|
290 |
|
291 |
-
# Main app screen (after login)
|
292 |
def main_app():
|
293 |
st.subheader(f"Welcome, {st.session_state.display_name}!")
|
294 |
st.write("Enter a prompt below to generate an image.")
|
@@ -308,15 +316,28 @@ def main_app():
|
|
308 |
image_result = generate_image(prompt, aspect_ratio, realism)
|
309 |
if isinstance(image_result, tuple) and len(image_result) == 2 and image_result[0] is not None:
|
310 |
image, image_url = image_result
|
311 |
-
st.image(image, caption="Generated Image", use_column_width=True)
|
312 |
|
313 |
-
#
|
314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
-
|
317 |
-
# Store image data in database
|
318 |
-
store_image_data_in_db(st.session_state.current_user, prompt, aspect_ratio, realism, cloud_storage_url)
|
319 |
-
st.success("Image stored to database successfully!")
|
320 |
|
321 |
download_path = download_image(image_url)
|
322 |
if download_path:
|
@@ -327,7 +348,6 @@ def main_app():
|
|
327 |
st.warning("Please enter a prompt to generate an image.")
|
328 |
|
329 |
st.header("Your Generated Images")
|
330 |
-
|
331 |
# Initialize the current window, if it doesn't exist in session state
|
332 |
if "current_window_start" not in st.session_state:
|
333 |
st.session_state.current_window_start = 0
|
@@ -363,22 +383,24 @@ def main_app():
|
|
363 |
|
364 |
cols = st.columns(num_images_to_display)
|
365 |
for i, image_data in enumerate(images_for_window):
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
else:
|
374 |
st.write("No image generated yet!")
|
375 |
|
376 |
# Logout button
|
377 |
if st.button("Logout", on_click=logout_callback):
|
378 |
-
pass
|
379 |
-
# Render the appropriate screen based on login status
|
380 |
-
if st.session_state.logged_in:
|
381 |
-
main_app()
|
382 |
-
else:
|
383 |
-
registration_form()
|
384 |
-
login_form()
|
|
|
9 |
import tempfile
|
10 |
import mimetypes
|
11 |
import uuid
|
12 |
+
import io
|
13 |
|
14 |
# Load Firebase credentials from Hugging Face Secrets
|
15 |
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
|
|
|
36 |
st.session_state.current_user = None
|
37 |
if "display_name" not in st.session_state:
|
38 |
st.session_state.display_name = None
|
39 |
+
if "window_size" not in st.session_state:
|
40 |
+
st.session_state.window_size = 5
|
41 |
+
if "current_window_start" not in st.session_state:
|
42 |
+
st.session_state.current_window_start = 0
|
43 |
+
|
44 |
+
|
|
|
|
|
|
|
45 |
TOKEN = os.getenv("TOKEN0")
|
46 |
API_URL = os.getenv("API_URL")
|
47 |
token_id = 0
|
|
|
124 |
st.session_state.logged_in = True
|
125 |
st.session_state.current_user = user.uid
|
126 |
st.session_state.display_name = user.display_name # Store the display name
|
|
|
|
|
127 |
st.success("Logged in successfully!")
|
128 |
except Exception as e:
|
129 |
st.error(f"Login failed: {e}")
|
|
|
133 |
st.session_state.logged_in = False
|
134 |
st.session_state.current_user = None
|
135 |
st.session_state.display_name = None
|
|
|
|
|
136 |
st.info("Logged out successfully!")
|
137 |
|
138 |
# Function to get image from url
|
|
|
145 |
response.raise_for_status()
|
146 |
image = Image.open(BytesIO(response.content))
|
147 |
return image, url # Return the image and the URL
|
|
|
|
|
148 |
except requests.exceptions.RequestException as e:
|
149 |
return f"Error fetching image: {e}", None
|
150 |
except Exception as e:
|
|
|
209 |
return None
|
210 |
|
211 |
# Function to store image and related data in Firebase
|
212 |
+
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
|
213 |
try:
|
214 |
ref = db.reference(f'users/{user_id}/images')
|
215 |
new_image_ref = ref.push()
|
|
|
218 |
'aspect_ratio': aspect_ratio,
|
219 |
'realism': realism,
|
220 |
'image_url': image_url,
|
221 |
+
'thumbnail_url' : thumbnail_url,
|
222 |
'timestamp': {'.sv': 'timestamp'}
|
223 |
})
|
224 |
st.success("Image data saved successfully!")
|
|
|
226 |
st.error(f"Failed to save image data: {e}")
|
227 |
|
228 |
#Function to upload image to cloud storage
|
229 |
+
def upload_image_to_storage(image, user_id, is_thumbnail = False):
|
230 |
try:
|
231 |
bucket = storage.bucket()
|
232 |
image_id = str(uuid.uuid4())
|
233 |
+
if is_thumbnail:
|
234 |
+
file_path = f"user_images/{user_id}/thumbnails/{image_id}.png" # path for thumbnail
|
235 |
+
else:
|
236 |
+
file_path = f"user_images/{user_id}/{image_id}.png" # path for high resolution images
|
237 |
blob = bucket.blob(file_path)
|
238 |
|
239 |
# Convert PIL Image to BytesIO object
|
|
|
267 |
st.error(f"Failed to fetch image data from database: {e}")
|
268 |
return []
|
269 |
|
270 |
+
# Function to create low resolution thumbnail
|
271 |
+
def create_thumbnail(image, thumbnail_size = (150,150)):
|
272 |
+
try:
|
273 |
+
img_byte_arr = BytesIO()
|
274 |
+
image.thumbnail(thumbnail_size)
|
275 |
+
image.save(img_byte_arr, format='PNG')
|
276 |
+
img_byte_arr = img_byte_arr.getvalue()
|
277 |
+
thumbnail = Image.open(io.BytesIO(img_byte_arr)) # convert byte to PIL image
|
278 |
+
return thumbnail
|
279 |
+
except Exception as e:
|
280 |
+
st.error(f"Failed to create thumbnail: {e}")
|
281 |
+
return None
|
282 |
+
|
283 |
# Registration form
|
284 |
def registration_form():
|
285 |
with st.form("Registration"):
|
|
|
297 |
password = st.text_input("Password", type="password", key="login_password")
|
298 |
submit_button = st.form_submit_button("Login", on_click=login_callback)
|
299 |
|
|
|
300 |
def main_app():
|
301 |
st.subheader(f"Welcome, {st.session_state.display_name}!")
|
302 |
st.write("Enter a prompt below to generate an image.")
|
|
|
316 |
image_result = generate_image(prompt, aspect_ratio, realism)
|
317 |
if isinstance(image_result, tuple) and len(image_result) == 2 and image_result[0] is not None:
|
318 |
image, image_url = image_result
|
|
|
319 |
|
320 |
+
#create thumbnail
|
321 |
+
thumbnail = create_thumbnail(image)
|
322 |
+
|
323 |
+
if thumbnail:
|
324 |
+
# Upload thumbnail to cloud storage and store url
|
325 |
+
thumbnail_url = upload_image_to_storage(thumbnail, st.session_state.current_user, is_thumbnail=True)
|
326 |
+
|
327 |
+
if thumbnail_url:
|
328 |
+
# Upload image to cloud storage and store url
|
329 |
+
cloud_storage_url = upload_image_to_storage(image, st.session_state.current_user, is_thumbnail=False)
|
330 |
+
|
331 |
+
if cloud_storage_url:
|
332 |
+
# Store image data in database
|
333 |
+
store_image_data_in_db(st.session_state.current_user, prompt, aspect_ratio, realism, cloud_storage_url, thumbnail_url)
|
334 |
+
st.success("Image stored to database successfully!")
|
335 |
+
else:
|
336 |
+
st.error("Failed to upload thumbnail to cloud storage.")
|
337 |
+
else:
|
338 |
+
st.error("Failed to create thumbnail")
|
339 |
|
340 |
+
st.image(image, caption="Generated Image", use_column_width=True) # Display high res image on top
|
|
|
|
|
|
|
341 |
|
342 |
download_path = download_image(image_url)
|
343 |
if download_path:
|
|
|
348 |
st.warning("Please enter a prompt to generate an image.")
|
349 |
|
350 |
st.header("Your Generated Images")
|
|
|
351 |
# Initialize the current window, if it doesn't exist in session state
|
352 |
if "current_window_start" not in st.session_state:
|
353 |
st.session_state.current_window_start = 0
|
|
|
383 |
|
384 |
cols = st.columns(num_images_to_display)
|
385 |
for i, image_data in enumerate(images_for_window):
|
386 |
+
with cols[i]:
|
387 |
+
if image_data.get('thumbnail_url'):
|
388 |
+
if st.button(image_data.get('prompt') ,key=f"thumbnail_{i}"):
|
389 |
+
st.image(image_data['image_url'], width = 200)
|
390 |
+
st.image(image_data['thumbnail_url'], width = 150)
|
391 |
+
st.write(f"**Prompt:** {image_data['prompt']}")
|
392 |
+
st.write(f"**Aspect Ratio:** {image_data['aspect_ratio']}")
|
393 |
+
st.write(f"**Realism:** {image_data['realism']}")
|
394 |
+
st.markdown("---")
|
395 |
+
else:
|
396 |
+
st.image(image_data['image_url'], width = 150)
|
397 |
+
st.write(f"**Prompt:** {image_data['prompt']}")
|
398 |
+
st.write(f"**Aspect Ratio:** {image_data['aspect_ratio']}")
|
399 |
+
st.write(f"**Realism:** {image_data['realism']}")
|
400 |
+
st.markdown("---")
|
401 |
else:
|
402 |
st.write("No image generated yet!")
|
403 |
|
404 |
# Logout button
|
405 |
if st.button("Logout", on_click=logout_callback):
|
406 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|