Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,20 @@
|
|
1 |
import streamlit as st
|
2 |
import firebase_admin
|
3 |
-
from firebase_admin import credentials, auth, db
|
4 |
import os
|
5 |
import json
|
6 |
import requests
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Load Firebase credentials from Hugging Face Secrets
|
9 |
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
|
10 |
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
|
|
|
|
|
11 |
if firebase_creds:
|
12 |
firebase_creds = json.loads(firebase_creds)
|
13 |
else:
|
@@ -17,7 +24,8 @@ else:
|
|
17 |
if not firebase_admin._apps:
|
18 |
cred = credentials.Certificate(firebase_creds)
|
19 |
firebase_admin.initialize_app(cred, {
|
20 |
-
'databaseURL': 'https://creative-623ef-default-rtdb.firebaseio.com/'
|
|
|
21 |
})
|
22 |
|
23 |
# Initialize session state
|
@@ -28,6 +36,12 @@ if "current_user" not in st.session_state:
|
|
28 |
if "display_name" not in st.session_state:
|
29 |
st.session_state.display_name = None
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def send_verification_email(id_token):
|
33 |
url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
|
@@ -115,18 +129,117 @@ def logout_callback():
|
|
115 |
st.session_state.display_name = None
|
116 |
st.info("Logged out successfully!")
|
117 |
|
118 |
-
# Function to
|
119 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
try:
|
121 |
-
ref = db.reference(f'users/{user_id}/
|
122 |
-
|
123 |
-
|
124 |
-
'
|
|
|
|
|
|
|
125 |
'timestamp': {'.sv': 'timestamp'}
|
126 |
})
|
127 |
-
st.success("
|
128 |
except Exception as e:
|
129 |
-
st.error(f"Failed to save
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
# Registration form
|
132 |
def registration_form():
|
@@ -148,16 +261,40 @@ def login_form():
|
|
148 |
# Main app screen (after login)
|
149 |
def main_app():
|
150 |
st.subheader(f"Welcome, {st.session_state.display_name}!")
|
151 |
-
st.write("Enter a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
# Logout button
|
163 |
if st.button("Logout", on_click=logout_callback):
|
|
|
1 |
import streamlit as st
|
2 |
import firebase_admin
|
3 |
+
from firebase_admin import credentials, auth, db, storage
|
4 |
import os
|
5 |
import json
|
6 |
import requests
|
7 |
+
from io import BytesIO
|
8 |
+
from PIL import Image
|
9 |
+
import tempfile
|
10 |
+
import mimetypes
|
11 |
+
import uuid
|
12 |
|
13 |
# Load Firebase credentials from Hugging Face Secrets
|
14 |
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
|
15 |
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
|
16 |
+
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")
|
17 |
+
|
18 |
if firebase_creds:
|
19 |
firebase_creds = json.loads(firebase_creds)
|
20 |
else:
|
|
|
24 |
if not firebase_admin._apps:
|
25 |
cred = credentials.Certificate(firebase_creds)
|
26 |
firebase_admin.initialize_app(cred, {
|
27 |
+
'databaseURL': 'https://creative-623ef-default-rtdb.firebaseio.com/',
|
28 |
+
'storageBucket': FIREBASE_STORAGE_BUCKET
|
29 |
})
|
30 |
|
31 |
# Initialize session state
|
|
|
36 |
if "display_name" not in st.session_state:
|
37 |
st.session_state.display_name = None
|
38 |
|
39 |
+
TOKEN = os.getenv("TOKEN0")
|
40 |
+
API_URL = os.getenv("API_URL")
|
41 |
+
token_id = 0
|
42 |
+
tokens_tried = 0
|
43 |
+
no_of_accounts = 7
|
44 |
+
model_id = os.getenv("MODEL_ID")
|
45 |
|
46 |
def send_verification_email(id_token):
|
47 |
url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
|
|
|
129 |
st.session_state.display_name = None
|
130 |
st.info("Logged out successfully!")
|
131 |
|
132 |
+
# Function to get image from url
|
133 |
+
def get_image_from_url(url):
|
134 |
+
"""
|
135 |
+
Fetches and returns an image from a given URL, converting to PNG if needed.
|
136 |
+
"""
|
137 |
+
try:
|
138 |
+
response = requests.get(url, stream=True)
|
139 |
+
response.raise_for_status()
|
140 |
+
image = Image.open(BytesIO(response.content))
|
141 |
+
return image, url # Return the image and the URL
|
142 |
+
|
143 |
+
|
144 |
+
except requests.exceptions.RequestException as e:
|
145 |
+
return f"Error fetching image: {e}", None
|
146 |
+
except Exception as e:
|
147 |
+
return f"Error processing image: {e}", None
|
148 |
+
|
149 |
+
# Function to generate image
|
150 |
+
def generate_image(prompt, aspect_ratio, realism):
|
151 |
+
global token_id
|
152 |
+
global TOKEN
|
153 |
+
global tokens_tried
|
154 |
+
global no_of_accounts
|
155 |
+
global model_id
|
156 |
+
payload = {
|
157 |
+
"id": model_id,
|
158 |
+
"inputs": [prompt, aspect_ratio, str(realism).lower()],
|
159 |
+
}
|
160 |
+
headers = {"Authorization": f"Bearer {TOKEN}"}
|
161 |
+
|
162 |
+
try:
|
163 |
+
response_data = requests.post(API_URL, json=payload, headers=headers).json()
|
164 |
+
if "error" in response_data:
|
165 |
+
if 'error 429' in response_data['error']:
|
166 |
+
if tokens_tried < no_of_accounts:
|
167 |
+
token_id = (token_id + 1) % (no_of_accounts)
|
168 |
+
tokens_tried += 1
|
169 |
+
TOKEN = os.getenv(f"TOKEN{token_id}")
|
170 |
+
response_data = generate_image(prompt, aspect_ratio, realism)
|
171 |
+
tokens_tried = 0
|
172 |
+
return response_data
|
173 |
+
return "No credits available", None
|
174 |
+
return response_data, None
|
175 |
+
elif "output" in response_data:
|
176 |
+
url = response_data['output']
|
177 |
+
image, url = get_image_from_url(url)
|
178 |
+
return image, url # Return the image and the URL
|
179 |
+
else:
|
180 |
+
return "Error: Unexpected response from server", None
|
181 |
+
except Exception as e:
|
182 |
+
return f"Error", None
|
183 |
+
|
184 |
+
def download_image(image_url):
|
185 |
+
if not image_url:
|
186 |
+
return None # Return None if image_url is empty
|
187 |
+
try:
|
188 |
+
response = requests.get(image_url, stream=True)
|
189 |
+
response.raise_for_status()
|
190 |
+
|
191 |
+
# Get the content type from the headers
|
192 |
+
content_type = response.headers.get('content-type')
|
193 |
+
extension = mimetypes.guess_extension(content_type)
|
194 |
+
|
195 |
+
if not extension:
|
196 |
+
extension = ".png" # Default to .png if can't determine the extension
|
197 |
+
|
198 |
+
# Create a temporary file with the correct extension
|
199 |
+
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
|
200 |
+
for chunk in response.iter_content(chunk_size=8192):
|
201 |
+
tmp_file.write(chunk)
|
202 |
+
temp_file_path = tmp_file.name
|
203 |
+
return temp_file_path
|
204 |
+
except Exception as e:
|
205 |
+
return None
|
206 |
+
|
207 |
+
# Function to store image and related data in Firebase
|
208 |
+
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url):
|
209 |
try:
|
210 |
+
ref = db.reference(f'users/{user_id}/images')
|
211 |
+
new_image_ref = ref.push()
|
212 |
+
new_image_ref.set({
|
213 |
+
'prompt': prompt,
|
214 |
+
'aspect_ratio': aspect_ratio,
|
215 |
+
'realism': realism,
|
216 |
+
'image_url': image_url,
|
217 |
'timestamp': {'.sv': 'timestamp'}
|
218 |
})
|
219 |
+
st.success("Image data saved successfully!")
|
220 |
except Exception as e:
|
221 |
+
st.error(f"Failed to save image data: {e}")
|
222 |
+
|
223 |
+
#Function to upload image to cloud storage
|
224 |
+
def upload_image_to_storage(image, user_id):
|
225 |
+
try:
|
226 |
+
bucket = storage.bucket()
|
227 |
+
image_id = str(uuid.uuid4())
|
228 |
+
file_path = f"user_images/{user_id}/{image_id}.png"
|
229 |
+
blob = bucket.blob(file_path)
|
230 |
+
|
231 |
+
# Convert PIL Image to BytesIO object
|
232 |
+
img_byte_arr = BytesIO()
|
233 |
+
image.save(img_byte_arr, format='PNG')
|
234 |
+
img_byte_arr = img_byte_arr.getvalue()
|
235 |
+
|
236 |
+
blob.upload_from_string(img_byte_arr, content_type='image/png')
|
237 |
+
blob.make_public()
|
238 |
+
image_url = blob.public_url
|
239 |
+
return image_url
|
240 |
+
except Exception as e:
|
241 |
+
st.error(f"Failed to upload image to cloud storage: {e}")
|
242 |
+
return None
|
243 |
|
244 |
# Registration form
|
245 |
def registration_form():
|
|
|
261 |
# Main app screen (after login)
|
262 |
def main_app():
|
263 |
st.subheader(f"Welcome, {st.session_state.display_name}!")
|
264 |
+
st.write("Enter a prompt below to generate an image.")
|
265 |
+
|
266 |
+
# Input fields
|
267 |
+
prompt = st.text_input("Prompt", key="image_prompt", placeholder="Describe the image you want to generate")
|
268 |
+
aspect_ratio = st.radio(
|
269 |
+
"Aspect Ratio",
|
270 |
+
options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
|
271 |
+
index=5
|
272 |
+
)
|
273 |
+
realism = st.checkbox("Realism", value=False)
|
274 |
|
275 |
+
if st.button("Generate Image"):
|
276 |
+
if prompt:
|
277 |
+
with st.spinner("Generating Image..."):
|
278 |
+
image_result = generate_image(prompt, aspect_ratio, realism)
|
279 |
+
if isinstance(image_result, tuple) and len(image_result) == 2 and image_result[0] is not None:
|
280 |
+
image, image_url = image_result
|
281 |
+
st.image(image, caption="Generated Image", use_column_width=True)
|
282 |
|
283 |
+
# Upload image to cloud storage and store url
|
284 |
+
cloud_storage_url = upload_image_to_storage(image, st.session_state.current_user)
|
285 |
+
|
286 |
+
if cloud_storage_url:
|
287 |
+
# Store image data in database
|
288 |
+
store_image_data_in_db(st.session_state.current_user, prompt, aspect_ratio, realism, cloud_storage_url)
|
289 |
+
st.success("Image stored to database successfully!")
|
290 |
+
|
291 |
+
download_path = download_image(image_url)
|
292 |
+
if download_path:
|
293 |
+
st.download_button(label="Download Image", data = open(download_path, "rb"), file_name = f"image.png")
|
294 |
+
else:
|
295 |
+
st.error(f"Image generation failed: {image_result}")
|
296 |
+
else:
|
297 |
+
st.warning("Please enter a prompt to generate an image.")
|
298 |
|
299 |
# Logout button
|
300 |
if st.button("Logout", on_click=logout_callback):
|