File size: 20,080 Bytes
10dce96
 
ca14d7c
43b2fb8
5c827eb
c36922a
ca14d7c
 
 
 
 
54225eb
ceda783
32ba94f
 
11f08dd
ca14d7c
 
32ba94f
 
 
 
580c204
0f68bb2
 
ceda783
580c204
ca14d7c
 
580c204
10dce96
924e5d5
 
 
 
 
e689a79
 
54225eb
 
 
 
4e552cf
 
54225eb
ca14d7c
 
 
 
f866df7
ca14d7c
10dce96
11f08dd
 
 
 
 
 
 
ba6281a
11f08dd
 
ba6281a
11f08dd
 
 
 
e689a79
924e5d5
 
 
 
e689a79
10dce96
e689a79
924e5d5
ba6281a
e689a79
 
ba6281a
11f08dd
 
e689a79
11f08dd
 
 
 
 
 
 
 
 
 
ceda783
 
11f08dd
 
 
 
 
 
 
 
924e5d5
 
10dce96
6e31aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba6281a
6e31aea
 
 
 
 
 
 
 
10dce96
67ca7f8
 
 
 
 
 
 
 
ca14d7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba6281a
ca14d7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54225eb
580c204
ca14d7c
 
 
 
 
 
 
54225eb
f2725dc
580c204
ca14d7c
580c204
ca14d7c
 
 
54225eb
ca14d7c
 
 
54225eb
 
 
33a3a69
ca14d7c
 
 
 
 
 
 
 
 
 
 
 
 
 
580c204
9fc847d
 
 
 
 
 
 
 
ba6281a
9fc847d
 
 
 
 
 
 
 
 
 
54225eb
 
 
 
 
 
 
 
 
 
 
 
 
10dce96
 
 
 
 
e689a79
10dce96
924e5d5
10dce96
924e5d5
 
 
 
e689a79
924e5d5
 
6e16db5
924e5d5
e689a79
ca14d7c
 
 
 
 
ba6281a
 
ca14d7c
 
 
10dce96
ca14d7c
2f22314
 
 
dd20192
c62b281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54225eb
c62b281
dd20192
ca14d7c
2f22314
 
 
9fc847d
4e552cf
3f0f169
1345050
 
 
 
ba6281a
4e552cf
 
ba6281a
3f0f169
 
 
 
 
1345050
3f0f169
 
 
1345050
3f0f169
 
1345050
ba6281a
1345050
 
 
 
 
 
ba6281a
1345050
3f0f169
1345050
 
b6f4a85
 
1345050
d8544b6
4e552cf
 
 
86035bf
 
4e552cf
 
 
 
 
 
b9d5493
1345050
9fc847d
4e552cf
 
ba6281a
 
 
 
 
 
 
 
86035bf
 
4e552cf
 
924e5d5
 
7f7b93a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import streamlit as st
import firebase_admin
from firebase_admin import credentials, auth, db, storage
import os
import json
import requests
from io import BytesIO
from PIL import Image
import tempfile
import mimetypes
import uuid
import io

# Load Firebase configuration from Hugging Face Secrets (exposed as environment variables).
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")

if firebase_creds:
    # The secret holds the service-account JSON as a string.
    firebase_creds = json.loads(firebase_creds)
else:
    st.error("Firebase credentials not found. Please check your secrets.")
    # Without credentials, credentials.Certificate(None) below would raise and
    # dump a traceback into the page; end this script run cleanly instead.
    st.stop()

# Initialize Firebase only once per process; Streamlit re-executes this script
# on every interaction, and initialize_app() refuses a second default app.
if not firebase_admin._apps:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://creative-623ef-default-rtdb.firebaseio.com/',
        'storageBucket': FIREBASE_STORAGE_BUCKET
    })

# Seed per-browser-session defaults. Keys that already exist keep their values
# across Streamlit reruns; only missing keys are filled in.
for _state_key, _state_default in (
    ("logged_in", False),
    ("current_user", None),
    ("display_name", None),
    ("window_size", 5),            # thumbnails shown per gallery page
    ("current_window_start", 0),   # index of the first thumbnail on the current page
    ("selected_image", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Image-generation backend configuration, read from environment secrets.
API_URL = os.getenv("API_URL")
model_id = os.getenv("MODEL_ID")

# Token-rotation state for the generation API: generate_image() cycles
# through the secrets TOKEN0 .. TOKEN{no_of_accounts - 1}.
no_of_accounts = 7
token_id = 0        # index of the token currently in use
tokens_tried = 0    # token-rotation retry counter
TOKEN = os.getenv(f"TOKEN{token_id}")

def send_verification_email(id_token):
    """Ask Firebase Auth to e-mail an address-verification link to a user.

    Args:
        id_token: Firebase ID token of the user, as returned by the
            ``accounts:signInWithPassword`` REST endpoint.

    Returns:
        ``{'status': 'success', 'email': <address or None>}`` when the request
        is accepted, otherwise ``{'status': 'error', 'message': <reason>}``.
        Network failures and non-JSON responses are reported through the
        error dict rather than raised.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
    headers = {'Content-Type': 'application/json'}
    data = {
        'requestType': 'VERIFY_EMAIL',
        'idToken': id_token
    }

    try:
        # Bounded timeout so a stalled call cannot hang the Streamlit run.
        response = requests.post(url, headers=headers, json=data, timeout=30)
        result = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        return {'status': 'error', 'message': str(e)}

    if 'error' in result:
        error = result['error']
        message = error.get('message', 'Unknown error') if isinstance(error, dict) else str(error)
        return {'status': 'error', 'message': message}
    return {'status': 'success', 'email': result.get('email')}

# Callback for registration
def register_callback():
    """Create a Firebase account from the registration form and send a verification e-mail.

    Reads ``reg_email``, ``reg_password`` and ``reg_display_name`` from
    ``st.session_state`` (populated by the registration form). Creates the
    user, sets their display name, then signs in through the REST API to get
    an ID token and request a verification e-mail. On success the ID token is
    kept in ``st.session_state.id_token``. All outcomes are reported through
    ``st.success`` / ``st.error``; nothing is raised.
    """
    email = st.session_state.reg_email
    password = st.session_state.reg_password
    display_name = st.session_state.reg_display_name
    try:
        # Step 1: Create a new user in Firebase
        user = auth.create_user(email=email, password=password)

        # Step 2: Update the user profile with the display name
        auth.update_user(user.uid, display_name=display_name)

        st.success("Registration successful! Sending verification email...")

        # Step 3: Sign in programmatically; the verification endpoint needs an id_token.
        url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
        data = {
            'email': email,
            'password': password,
            'returnSecureToken': True
        }
        response = requests.post(url, json=data, timeout=30)
        result = response.json()

        if 'idToken' not in result:
            # The response may lack an 'error' object; don't KeyError while reporting.
            error = result.get('error', {})
            message = error.get('message', 'Unknown error') if isinstance(error, dict) else error
            st.error(f"Failed to retrieve id_token: {message}")
            return

        id_token = result['idToken']
        st.session_state.id_token = id_token

        verification_result = send_verification_email(id_token)
        if verification_result['status'] == 'success':
            st.success(f"Verification email sent to {email}.")
        else:
            st.error(f"Failed to send verification email: {verification_result['message']}")
    except Exception as e:
        st.error(f"Registration failed: {e}")

# Callback for login
def login_callback():
    """Sign the user in by e-mail or by display name.

    Reads ``login_identifier`` and ``login_password`` from
    ``st.session_state``. The identifier is first tried as an e-mail
    address. If Firebase rejects that, every account whose display name
    equals the identifier is tried with the same password. On success,
    ``logged_in``, ``current_user`` (uid) and ``display_name`` are stored in
    session state. Failures are shown with ``st.error``; nothing is raised.
    """
    login_identifier = st.session_state.login_identifier
    password = st.session_state.login_password
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'

    def _sign_in(email):
        # Verify an email/password pair with the Firebase Auth REST API.
        data = {
            'email': email,
            'password': password,
            'returnSecureToken': True
        }
        return requests.post(url, json=data, timeout=30).json()

    def _mark_logged_in(user):
        # Record the authenticated user in session state and confirm.
        st.session_state.logged_in = True
        st.session_state.current_user = user.uid
        st.session_state.display_name = user.display_name
        st.success("Logged in successfully!")

    try:
        result = _sign_in(login_identifier)

        if 'idToken' in result:
            # The identifier was a valid e-mail/password pair.
            _mark_logged_in(auth.get_user_by_email(login_identifier))

        elif 'error' in result:
            # Fall back to treating the identifier as a display name.
            try:
                # iterate_all() follows pagination; list_users() alone only
                # returns the first page (max 1000 users).
                for user_info in auth.list_users().iterate_all():
                    if user_info.display_name != login_identifier or not user_info.email:
                        continue
                    if 'idToken' in _sign_in(user_info.email):
                        _mark_logged_in(user_info)
                        return

                raise Exception("User not found with provided credentials.")
            except Exception as e:
                st.error(f"Login failed: {e}")
        else:
            # Neither a token nor an error: the endpoint misbehaved.
            raise Exception("Error with sign-in endpoint")

    except Exception as e:
        st.error(f"Login failed: {e}")

# Callback for logout
def logout_callback():
    """Forget the signed-in user (and any open image view), then confirm the logout."""
    st.session_state.logged_in = False
    for state_key in ("current_user", "display_name", "selected_image"):
        st.session_state[state_key] = None
    st.info("Logged out successfully!")

# Function to get image from url
def get_image_from_url(url):
    """Download the resource at *url* and decode it as a PIL image.

    Returns:
        ``(image, url)`` on success. On failure, ``(message, None)``, where
        message is a human-readable string starting with
        ``"Error fetching image:"`` (HTTP/network problems) or
        ``"Error processing image:"`` (undecodable data).
    """
    try:
        # Bounded timeout so a stalled download cannot hang the app.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open() is lazy; force decoding now so corrupt data is reported
        # here instead of failing later when the image is used.
        image.load()
        return image, url
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None

# Function to generate image
def generate_image(prompt, aspect_ratio, realism):
    """Request an image from the generation API, rotating API tokens on rate limits.

    Args:
        prompt: Text description of the desired image.
        aspect_ratio: Aspect-ratio string forwarded to the model (e.g. "1:1").
        realism: Truthy/falsy flag, sent as the lowercase string "true"/"false".

    Returns:
        ``(PIL image, image_url)`` on success. Otherwise a 2-tuple whose first
        item describes the failure and whose second item is None: the server's
        error payload, ``"No credits available"``, an unexpected-response
        message, or ``"Error: <exception>"``.

    Side effects:
        On a rate-limit ("error 429") response the module-level ``TOKEN`` /
        ``token_id`` advance to the next ``TOKEN<n>`` secret. Each of the
        ``no_of_accounts`` tokens is tried at most once per call, and the last
        token tried stays selected for the next call.
    """
    global token_id
    global TOKEN
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }

    attempts = 0
    try:
        while True:
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(API_URL, json=payload, headers=headers).json()

            if "error" in response_data:
                # str() guards against non-string error payloads.
                if 'error 429' in str(response_data['error']):
                    attempts += 1
                    if attempts < no_of_accounts:
                        # Rate-limited: switch to the next account's token and retry.
                        token_id = (token_id + 1) % no_of_accounts
                        TOKEN = os.getenv(f"TOKEN{token_id}")
                        continue
                    return "No credits available", None
                return response_data, None

            if "output" in response_data:
                # get_image_from_url already returns (image, url) or (message, None).
                return get_image_from_url(response_data['output'])

            return "Error: Unexpected response from server", None
    except Exception as e:
        return f"Error: {e}", None

def download_image(image_url):
    """Download *image_url* into a named temporary file and return its path.

    The file extension is guessed from the response's Content-Type header,
    falling back to ".png". The file is created with ``delete=False``, so the
    caller owns it and should remove it when done.

    Returns:
        The temporary file path, or None if *image_url* is empty or the
        download fails for any reason (best effort; nothing is raised).
    """
    if not image_url:
        return None

    temp_file_path = None
    try:
        response = requests.get(image_url, stream=True, timeout=60)
        response.raise_for_status()

        # The header may be missing (guess_extension(None) raises) or carry
        # parameters such as "; charset=binary" that defeat the lookup.
        content_type = (response.headers.get('content-type') or '').split(';')[0].strip()
        extension = mimetypes.guess_extension(content_type) if content_type else None
        if not extension:
            extension = ".png"  # Default when the type is unknown

        with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
            temp_file_path = tmp_file.name
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
        return temp_file_path
    except Exception:
        # Don't leave a partially written file behind.
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None

# Function to store image and related data in Firebase
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
    """Append one generated-image record under ``users/<user_id>/images``.

    The record stores the generation settings, both storage URLs, and a
    server-assigned timestamp. Success or failure is reported via
    ``st.success`` / ``st.error``; nothing is raised or returned.
    """
    try:
        record_ref = db.reference(f'users/{user_id}/images').push()
        record_ref.set({
            'prompt': prompt,
            'aspect_ratio': aspect_ratio,
            'realism': realism,
            'image_url': image_url,
            'thumbnail_url': thumbnail_url,
            # Realtime Database server-value placeholder, resolved on write.
            'timestamp': {'.sv': 'timestamp'},
        })
        st.success("Image data saved successfully!")
    except Exception as exc:
        st.error(f"Failed to save image data: {exc}")

# Function to upload image to cloud storage
def upload_image_to_storage(image, user_id, is_thumbnail = False):
    """Upload a PIL image as PNG to the default Storage bucket and return its public URL.

    High-resolution images go to ``user_images/<user_id>/<uuid>.png`` and
    thumbnails to ``user_images/<user_id>/thumbnails/<uuid>.png``. The blob is
    made public. On any failure an ``st.error`` is shown and None is returned.
    """
    try:
        bucket = storage.bucket()
        folder = f"user_images/{user_id}/thumbnails" if is_thumbnail else f"user_images/{user_id}"
        blob = bucket.blob(f"{folder}/{uuid.uuid4()}.png")

        png_buffer = BytesIO()
        image.save(png_buffer, format='PNG')

        blob.upload_from_string(png_buffer.getvalue(), content_type='image/png')
        blob.make_public()
        return blob.public_url
    except Exception as exc:
        st.error(f"Failed to upload image to cloud storage: {exc}")
        return None

# Function to load image data from the database
def load_image_data(user_id, start_index, batch_size):
    """Return up to *batch_size* image records for *user_id*, newest first, skipping *start_index*.

    Fetches the newest ``start_index + batch_size`` records ordered by
    ``timestamp``, reverses them so the latest comes first, and drops the
    first *start_index*. Returns ``[]`` when nothing is stored or on error
    (errors are shown with ``st.error``).
    """
    try:
        query = db.reference(f'users/{user_id}/images').order_by_child('timestamp')
        snapshot = query.limit_to_last(start_index + batch_size).get()
        if not snapshot:
            return []
        newest_first = [record for _, record in reversed(list(snapshot.items()))]
        return newest_first[start_index:]
    except Exception as exc:
        st.error(f"Failed to fetch image data from database: {exc}")
        return []

# Function to create low resolution thumbnail
def create_thumbnail(image, thumbnail_size = (150,150)):
    """Return a PNG-encoded thumbnail of *image* that fits within *thumbnail_size*.

    The input image is NOT modified. PIL's ``Image.thumbnail`` resizes in
    place, so the work is done on a copy. The result is round-tripped through
    PNG bytes so the returned object is a PNG image.

    Returns:
        The thumbnail as a PIL image, or None on failure (reported via ``st.error``).
    """
    try:
        thumbnail = image.copy()
        thumbnail.thumbnail(thumbnail_size)  # in place; keeps aspect ratio
        png_bytes = BytesIO()
        thumbnail.save(png_bytes, format='PNG')
        return Image.open(io.BytesIO(png_bytes.getvalue()))
    except Exception as e:
        st.error(f"Failed to create thumbnail: {e}")
        return None

# Registration form
def registration_form():
    """Render the sign-up form; submitting it runs register_callback.

    Widget values land in st.session_state under reg_email,
    reg_display_name and reg_password, which is where the callback reads them.
    """
    with st.form("Registration"):
        st.subheader("Register")
        st.text_input("Email", key="reg_email")
        st.text_input("Display Name", key="reg_display_name")
        st.text_input("Password (min 6 characters)", type="password", key="reg_password")
        st.form_submit_button("Register", on_click=register_callback)

# Login form
def login_form():
    """Render the sign-in form; submitting it runs login_callback.

    Widget values land in st.session_state under login_identifier and
    login_password, which is where the callback reads them.
    """
    with st.form("Login"):
        st.subheader("Login")
        st.text_input("Email or Username", key="login_identifier")
        st.text_input("Password", type="password", key="login_password")
        st.form_submit_button("Login", on_click=login_callback)

def _offer_download(image_url, key_prefix):
    """Render a "Download Image" button for *image_url* without leaking files.

    The image is fetched to a temporary file by download_image(). Its bytes
    are read into memory and the temporary file is deleted right away, so no
    file handle or temp file outlives this call. Renders nothing if the
    download fails.
    """
    download_path = download_image(image_url)
    if not download_path:
        return
    try:
        with open(download_path, "rb") as fh:
            image_bytes = fh.read()
    finally:
        try:
            os.remove(download_path)
        except OSError:
            pass  # best effort: a leftover temp file is not fatal
    st.download_button(label="Download Image", data=image_bytes, file_name="image.png",
                       key=f"{key_prefix}_{uuid.uuid4()}")


def _show_generation_result(prompt, aspect_ratio, realism):
    """Generate an image, persist it (full image, thumbnail, DB record) and preview it.

    Every failure is reported with st.error and stops the pipeline at that step.
    """
    image_result = generate_image(prompt, aspect_ratio, realism)
    if not (isinstance(image_result, tuple) and len(image_result) == 2):
        st.error(f"Image generation failed: {image_result}")
        return

    image, image_url = image_result
    if not isinstance(image, Image.Image):
        # generate_image returns (error description, None) on failure.
        st.error(f"Image generation failed: {image}")
        return

    # Scale the preview to fit a 400x400 box while keeping the aspect ratio.
    preview_size = 400
    original_width, original_height = image.size
    scaling_factor = min(preview_size / original_width, preview_size / original_height)
    new_size = (int(original_width * scaling_factor), int(original_height * scaling_factor))
    resized_image = image.resize(new_size, Image.LANCZOS)

    user_id = st.session_state.current_user
    cloud_storage_url = upload_image_to_storage(image, user_id, is_thumbnail=False)
    if not cloud_storage_url:
        st.error("Failed to upload image to cloud storage.")
        return

    thumbnail = create_thumbnail(image)
    if thumbnail is None:
        st.error("Failed to create thumbnail")
        return

    thumbnail_url = upload_image_to_storage(thumbnail, user_id, is_thumbnail=True)
    if not thumbnail_url:
        st.error("Failed to upload thumbnail to cloud storage.")
        return

    store_image_data_in_db(user_id, prompt, aspect_ratio, realism, cloud_storage_url, thumbnail_url)
    st.success("Image stored to database successfully!")
    with st.container(border=True):
        st.image(resized_image, use_column_width=False)  # Display the resized image
        st.write(f"**Prompt:** {prompt}")
        st.write(f"**Aspect Ratio:** {aspect_ratio}")
        st.write(f"**Realism:** {realism}")
        _offer_download(image_url, "download_high_res")


def _render_gallery():
    """Render the paged strip of the user's previously generated images.

    Paging is driven by st.session_state.current_window_start and window_size.
    The start index is clamped so paging past the last image cannot leave an
    empty window, which would make st.columns(0) raise.
    """
    st.header("Your Generated Images")
    # Defensive defaults in case session state was cleared mid-session.
    if "current_window_start" not in st.session_state:
        st.session_state.current_window_start = 0
    if "window_size" not in st.session_state:
        st.session_state.window_size = 5  # The number of images to display at a time
    if "selected_image" not in st.session_state:
        st.session_state.selected_image = None

    col_left, _col_center, col_right = st.columns([1, 8, 1])
    with col_left:
        if st.button("◀️"):
            st.session_state.current_window_start = max(
                0, st.session_state.current_window_start - st.session_state.window_size)
    with col_right:
        if st.button("▶️"):
            st.session_state.current_window_start += st.session_state.window_size

    all_images = load_image_data(st.session_state.current_user, 0, 1000)  # newest first
    if not all_images:
        st.write("No image generated yet!")
        return

    window_size = max(1, st.session_state.window_size)
    last_window_start = ((len(all_images) - 1) // window_size) * window_size
    if st.session_state.current_window_start > last_window_start:
        st.session_state.current_window_start = last_window_start

    start_index = st.session_state.current_window_start
    images_for_window = all_images[start_index:start_index + window_size]

    cols = st.columns(len(images_for_window))
    for i, (col, image_data) in enumerate(zip(cols, images_for_window)):
        with col:
            if image_data.get('thumbnail_url') and image_data.get('image_url'):
                if st.button("More", key=f"more_{i}"):
                    st.session_state.selected_image = image_data
                st.image(image_data['thumbnail_url'], width=150)  # display thumbnail
            else:
                # Records without a thumbnail show the full image with details inline.
                st.image(image_data['image_url'], width=150)
                st.write(f"**Prompt:** {image_data['prompt']}")
                st.write(f"**Aspect Ratio:** {image_data['aspect_ratio']}")
                st.write(f"**Realism:** {image_data['realism']}")
                st.markdown("---")


def _render_selected_image():
    """Show the full-size view of the image picked via "More", with a Close button."""
    selected = st.session_state.selected_image
    if not selected:
        return
    with st.container(border=True):
        st.image(selected['image_url'], use_column_width=True)
        st.write(f"**Prompt:** {selected['prompt']}")
        st.write(f"**Aspect Ratio:** {selected['aspect_ratio']}")
        st.write(f"**Realism:** {selected['realism']}")
        _offer_download(selected['image_url'], "download_overlay")

    if st.button("Close"):
        st.session_state.selected_image = None  # close the view when "Close" is clicked


def main_app():
    """Render the signed-in page: generation form, image gallery, detail view and logout."""
    st.subheader(f"Welcome, {st.session_state.display_name}!")
    st.write("Enter a prompt below to generate an image.")

    # Input fields
    prompt = st.text_input("Prompt", key="image_prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=5
    )
    realism = st.checkbox("Realism", value=False)

    if st.button("Generate Image"):
        if prompt:
            with st.spinner("Generating Image..."):
                _show_generation_result(prompt, aspect_ratio, realism)
        else:
            st.warning("Please enter a prompt to generate an image.")

    _render_gallery()
    _render_selected_image()

    st.button("Logout", on_click=logout_callback)
# Entry point: Streamlit re-runs this on every interaction; show the auth
# forms until a login callback flips st.session_state.logged_in.
if not st.session_state.logged_in:
    registration_form()
    login_form()
else:
    main_app()