Spaces:
Sleeping
Sleeping
File size: 8,062 Bytes
38c235b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import streamlit as st
import firebase_admin
from firebase_admin import credentials
from firebase_admin import auth
from firebase_admin import firestore
from firebase_admin import storage
import requests
from io import BytesIO
from PIL import Image
import os
import tempfile
import mimetypes
import uuid
# Generation API configuration. One API token per account is read from the
# TOKEN0 .. TOKEN<no_of_accounts - 1> environment variables; generate_image
# switches between them when a request is rate-limited.
# NOTE: Streamlit re-executes this script on every interaction, so these
# globals are re-initialised (back to TOKEN0) on each rerun.
TOKEN = os.getenv("TOKEN0")
API_URL = os.getenv("API_URL")
token_id = 0  # index of the account whose token is currently held in TOKEN
tokens_tried = 0  # token rotations performed during the current request
no_of_accounts = 11  # number of TOKEN<n> environment variables available
model_id = os.getenv("MODEL_ID")

# Initialize Firebase Admin SDK.
# The firebase_admin module (and its registry of initialised apps) survives
# Streamlit reruns, so calling initialize_app unconditionally would raise
# ValueError("The default Firebase app already exists") on the second run.
cred = credentials.Certificate("path/to/your/firebase/credentials.json")
try:
    firebase_admin.get_app()
except ValueError:
    firebase_admin.initialize_app(cred, {
        'storageBucket': 'your_storage_bucket_url'
    })
db = firestore.client()
bucket = storage.bucket()
def get_image_from_url(url, timeout=30):
    """Download the image at *url* and decode it with Pillow.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``(image, url)`` on success, where *image* is a fully loaded
        ``PIL.Image.Image``. On failure returns ``(message, None)`` where
        *message* is a human-readable error string.
    """
    try:
        with requests.get(url, timeout=timeout) as response:
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        # Image.open is lazy: force decoding now so corrupt or non-image data
        # is reported here instead of failing later when the image is rendered.
        image.load()
        return image, url
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
def _request_generation(payload, timeout=300):
    """POST *payload* to API_URL with the currently selected TOKEN and decode the JSON reply.

    Transport errors and non-JSON bodies raise from ``requests``; the caller
    handles them. The timeout is deliberately generous because generation may
    be slow — tune as needed.
    """
    headers = {"Authorization": f"Bearer {TOKEN}"}
    response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
    return response.json()


def _store_generated_image(source_url, user_id, prompt):
    """Copy the image at *source_url* into Cloud Storage and record it in Firestore.

    Returns the blob's public URL, or None if the download, upload, or
    Firestore write fails. The temporary local copy is always deleted.
    """
    download_path = download_image(source_url)
    if not download_path:
        return None
    try:
        # Create a unique filename for the image, keeping the downloaded extension
        image_filename = f"{uuid.uuid4()}{os.path.splitext(download_path)[1]}"
        blob = bucket.blob(f'users/{user_id}/{image_filename}')
        blob.upload_from_filename(download_path)
        # NOTE(review): public_url is only reachable if the object/bucket is
        # publicly readable — confirm the bucket's access settings.
        image_url = blob.public_url
        # CollectionReference.add returns (update_time, DocumentReference),
        # not a bare DocumentReference.
        _update_time, doc_ref = db.collection("images").add({
            "user_id": user_id,
            "prompt": prompt,
            "image_url": image_url,
        })
        print(f"Document added with id: {doc_ref.id}")
        return image_url
    except Exception as e:
        print(f"Error uploading image to firebase {e}")
        return None
    finally:
        try:
            os.remove(download_path)
        except OSError as e:
            print(f"Error removing temporary file {download_path}: {e}")


def generate_image(prompt, aspect_ratio, realism, user_id):
    """Generate an image for *prompt* and archive it under *user_id*.

    When the API answers with an "error 429" message, the module-level TOKEN
    and token_id are advanced to the next account (TOKEN<n> environment
    variables), trying each account at most once per call. The last working
    token is kept for subsequent calls.

    Returns:
        A 3-tuple ``(image_or_error, stored_url, download_url)``:
        * image_or_error -- the PIL image on success, otherwise an error string.
        * stored_url -- public URL of the Cloud Storage copy, or None if the
          image could not be stored.
        * download_url -- where the image can be fetched: the storage URL when
          stored, otherwise the generation API's output URL; None on error.
    """
    global token_id, TOKEN, tokens_tried
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    tokens_tried = 0
    try:
        response_data = _request_generation(payload)
        # A 429 means the current account is rate-limited / out of credits:
        # switch to the next account's token, visiting each account once.
        while isinstance(response_data, dict) and "error 429" in str(response_data.get("error", "")):
            if tokens_tried >= no_of_accounts - 1:
                return "No credits available", None, None
            token_id = (token_id + 1) % no_of_accounts
            tokens_tried += 1
            TOKEN = os.getenv(f"TOKEN{token_id}")
            response_data = _request_generation(payload)

        if not isinstance(response_data, dict):
            return "Error: Unexpected response from server", None, None
        if "error" in response_data:
            # Return the API's message as a string so callers can display it
            # (the raw dict is not an image).
            return f"Error: {response_data['error']}", None, None
        if "output" not in response_data:
            return "Error: Unexpected response from server", None, None

        image, source_url = get_image_from_url(response_data["output"])
        if isinstance(image, str):
            # Fetching/decoding failed; *image* holds the error message.
            return image, None, None

        stored_url = _store_generated_image(source_url, user_id, prompt)
        if stored_url:
            return image, stored_url, stored_url
        return image, None, source_url
    except Exception as e:
        print(f"Error: {e}")
        return "Error", None, None
    finally:
        tokens_tried = 0
def download_image(image_url, timeout=30):
    """Download *image_url* into a named temporary file and return its path.

    The file extension is derived from the response's Content-Type header
    (media-type parameters such as "; charset=binary" are ignored), falling
    back to ".png" when the header is missing or the type is unknown. The
    file is created with ``delete=False``; the caller is responsible for
    removing it.

    Args:
        image_url: URL to fetch; a falsy value returns None immediately.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The temporary file path, or None if the download failed.
    """
    if not image_url:
        return None  # Return None if image_url is empty
    temp_file_path = None
    try:
        with requests.get(image_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            content_type = response.headers.get('content-type') or ""
            media_type = content_type.split(";", 1)[0].strip()
            # guess_extension would raise on None and miss types with parameters
            extension = (mimetypes.guess_extension(media_type) if media_type else None) or ".png"
            # Create a temporary file with the correct extension
            with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
                temp_file_path = tmp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
        return temp_file_path
    except Exception as e:
        print(f"Error downloading image: {e}")
        # Don't leave a partially written file behind.
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None
def create_user_with_email_and_password(email, password):
    """Register a new Firebase Authentication account.

    Returns:
        ``(uid, None)`` when the account is created, otherwise
        ``(None, message)`` where *message* is the text of the exception
        raised while creating it.
    """
    try:
        new_uid = auth.create_user(email=email, password=password).uid
    except Exception as exc:
        print(f"Error creating user {exc}")
        return None, str(exc)
    return new_uid, None
def sign_in_with_email_and_password(email, password):
    """Verify *email* / *password* against Firebase Authentication.

    The Admin SDK cannot check passwords (``auth.get_user_by_email`` only
    proves the account exists), so this calls the Firebase Auth REST endpoint
    ``accounts:signInWithPassword`` with the project's Web API key, read from
    the ``FIREBASE_API_KEY`` environment variable.

    Returns:
        ``(uid, None)`` when the credentials are valid, otherwise
        ``(None, error_message)``.
    """
    api_key = os.getenv("FIREBASE_API_KEY")
    if not api_key:
        # Fail closed: without the key the password cannot be verified.
        message = "Sign-in is not configured: set the FIREBASE_API_KEY environment variable"
        print(f"Error logging in user: {message}")
        return None, message
    try:
        response = requests.post(
            "https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword",
            params={"key": api_key},
            json={"email": email, "password": password, "returnSecureToken": True},
            timeout=30,
        )
        data = response.json()
    except Exception as e:
        print(f"Error logging in user: {e}")
        return None, str(e)
    if not isinstance(data, dict):
        data = {}
    uid = data.get("localId")
    if response.ok and uid:
        return uid, None
    # Failures look like {"error": {"message": "INVALID_PASSWORD", ...}}.
    error = data.get("error")
    message = error.get("message") if isinstance(error, dict) else None
    message = message or f"Authentication failed (HTTP {response.status_code})"
    print(f"Error logging in user: {message}")
    return None, message
def _render_auth_form(button_label, auth_fn, success_message, failure_prefix):
    """Render the shared email/password form and run *auth_fn* when submitted.

    *auth_fn* is called as ``auth_fn(email, password)`` and must return
    ``(uid, error)``. On success the uid is stored in ``st.session_state``.
    """
    email = st.text_input("Email")
    password = st.text_input("Password", type="password")
    if not st.button(button_label):
        return
    if not (email and password):
        st.error("Please provide both email and password")
        return
    user_id, error = auth_fn(email, password)
    if user_id:
        st.success(success_message)
        st.session_state.user_id = user_id
        st.session_state.login_status = True
    else:
        st.error(f"{failure_prefix}: {error}")


def _image_to_png_bytes(image):
    """Encode a PIL image as PNG bytes, converting modes PNG cannot store to RGB."""
    if image.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _render_generator(user_id):
    """Render the prompt controls, run generation on click, and offer logout."""
    prompt = st.text_input("Prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        ["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=4,  # Default value is 16:9
    )
    realism = st.checkbox("Realism", value=False)
    if st.button("Generate Image"):
        if not prompt.strip():
            st.error("Please enter a prompt")
        else:
            image, url, file_url = generate_image(prompt, aspect_ratio, realism, user_id)
            if isinstance(image, str):
                st.error(image)
            elif image:
                st.image(image, caption="Generated Image", use_column_width=True)
                if url:
                    st.write(f"Firestore URL: {url}")
                if file_url:
                    # Serve the already-decoded image rather than re-fetching the
                    # URL: avoids an untimed second download and a crash or an
                    # error body if the URL is unreachable, and labels the
                    # file with its real format.
                    st.download_button(
                        label="Download Image",
                        data=_image_to_png_bytes(image),
                        file_name="image.png",
                        mime="image/png",
                    )
    if st.button("Log Out"):
        st.session_state.user_id = None
        st.rerun()


def main():
    """Streamlit entry point: show the login/signup form or the generator UI."""
    st.title("Image Generator")
    # Initialize session state
    if "user_id" not in st.session_state:
        st.session_state.user_id = None
    # Check if user is logged in
    if not st.session_state.user_id:
        st.subheader("Login/Signup")
        st.session_state.login_status = False
        login_option = st.radio("Select Option", ("Login", "Sign Up"))
        if login_option == "Login":
            _render_auth_form("Log In", sign_in_with_email_and_password,
                              "Logged in successfully!", "Login Failed")
        elif login_option == "Sign Up":
            _render_auth_form("Sign Up", create_user_with_email_and_password,
                              "Signed up successfully!", "Sign Up Failed")
    # Main app logic if logged in
    if st.session_state.user_id:
        _render_generator(st.session_state.user_id)
# Script entry point (`streamlit run` executes this file as __main__).
if __name__ == "__main__":
    main()