Spaces:
Sleeping
Sleeping
import streamlit as st | |
import firebase_admin | |
from firebase_admin import credentials | |
from firebase_admin import auth | |
from firebase_admin import firestore | |
from firebase_admin import storage | |
import requests | |
from io import BytesIO | |
from PIL import Image | |
import os | |
import tempfile | |
import mimetypes | |
import uuid | |
# --- Generation API configuration (read from the environment) ---
TOKEN = os.getenv("TOKEN0")  # bearer token currently in use; generate_image switches to TOKEN{i} on rate limits
API_URL = os.getenv("API_URL")
token_id = 0  # index i of the TOKEN{i} environment variable currently in use
tokens_tried = 0  # token-rotation bookkeeping for generate_image
# Number of TOKEN{i} accounts available for rotation (TOKEN0..TOKEN{n-1}); defaults to 11.
no_of_accounts = int(os.getenv("NO_OF_ACCOUNTS", "11"))
model_id = os.getenv("MODEL_ID")

# --- Firebase Admin SDK ---
# The defaults are placeholders; set these environment variables in a real deployment.
FIREBASE_CREDENTIALS_PATH = os.getenv(
    "FIREBASE_CREDENTIALS", "path/to/your/firebase/credentials.json"
)
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET", "your_storage_bucket_url")
cred = credentials.Certificate(FIREBASE_CREDENTIALS_PATH)
# Streamlit re-executes this script on every interaction, while firebase_admin
# keeps its default app alive in module state; a second initialize_app() call
# raises ValueError. Only initialize when no default app exists yet.
try:
    firebase_admin.get_app()
except ValueError:
    firebase_admin.initialize_app(cred, {"storageBucket": FIREBASE_STORAGE_BUCKET})
db = firestore.client()
bucket = storage.bucket()
def get_image_from_url(url, timeout=60):
    """
    Fetch an image from ``url`` and decode it with Pillow.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``(image, url)`` on success, where ``image`` is a fully decoded PIL image.
        ``(error_message, None)`` on failure; callers detect failure by checking
        whether the first element is a ``str``.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy: force decoding now so corrupt data is reported
        # here instead of surfacing later when the image is rendered.
        image.load()
        return image, url  # Return the image and the URL
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
def _is_rate_limited(response_data):
    """Return True if an API response reports an "error 429" rate-limit error."""
    return isinstance(response_data, dict) and "error 429" in str(response_data.get("error", ""))


def _store_in_firebase(source_url, user_id, prompt):
    """
    Copy the image at ``source_url`` into Cloud Storage and record it in Firestore.

    Returns the blob's public URL, or None if the download or the upload failed.
    The temporary download is always removed.
    """
    download_path = download_image(source_url)
    if not download_path:
        return None
    try:
        # Unique object name; keep the extension chosen by download_image.
        image_filename = f"{uuid.uuid4()}{os.path.splitext(download_path)[1]}"
        blob = bucket.blob(f'users/{user_id}/{image_filename}')
        blob.upload_from_filename(download_path)
        image_url = blob.public_url
        # CollectionReference.add returns (update_time, DocumentReference),
        # not a bare reference.
        _, doc_ref = db.collection("images").add({
            "user_id": user_id,
            "prompt": prompt,
            "image_url": image_url,
        })
        print(f"Document added with id: {doc_ref.id}")
        return image_url
    except Exception as e:
        print(f"Error uploading image to firebase {e}")
        return None
    finally:
        # Best-effort cleanup; runs on success and on upload failure alike.
        try:
            os.remove(download_path)
        except OSError:
            pass


def generate_image(prompt, aspect_ratio, realism, user_id, timeout=120):
    """
    Request an image from the generation API and persist it to Firebase.

    When the API answers with an "error 429" rate-limit error, the request is
    retried with the next account's token (environment variables
    TOKEN0..TOKEN{no_of_accounts-1}). Each account is tried at most once per
    call, and the last token tried stays selected for later calls.

    Args:
        prompt: Text description sent to the model.
        aspect_ratio: Aspect-ratio string, e.g. "16:9".
        realism: Truthy flag, sent to the API as "true"/"false".
        user_id: Firebase uid; used as the Storage folder and Firestore field.
        timeout: Seconds to wait for the generation API to respond.

    Returns:
        A 3-tuple ``(image, image_url, file_url)``:
        * stored in Firebase: (PIL image, public URL, public URL)
        * generated but not stored: (PIL image, None, source URL)
        * failure: (error message str, None, None)
    """
    global token_id
    global TOKEN
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    try:
        for attempt in range(max(no_of_accounts, 1)):
            if attempt:
                # The previous attempt was rate limited: switch accounts.
                token_id = (token_id + 1) % no_of_accounts
                TOKEN = os.getenv(f"TOKEN{token_id}")
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(
                API_URL, json=payload, headers=headers, timeout=timeout
            ).json()
            if not _is_rate_limited(response_data):
                break
        else:
            return "No credits available", None, None

        if "error" in response_data:
            # Return a message, never the raw dict: callers treat any
            # non-str first element as an image.
            return f"Error: {response_data['error']}", None, None
        if "output" not in response_data:
            return "Error: Unexpected response from server", None, None

        image, source_url = get_image_from_url(response_data['output'])
        if isinstance(image, str):
            return image, None, None  # fetch/decode failed; image is the message
        image_url = _store_in_firebase(source_url, user_id, prompt)
        if image_url:
            return image, image_url, image_url
        return image, None, source_url  # Return the image and the URL
    except Exception as e:
        print(f"Error: {e}")
        return "Error", None, None
def download_image(image_url, timeout=60):
    """
    Download ``image_url`` into a named temporary file and return its path.

    The file suffix is derived from the response's Content-Type header. MIME
    parameters such as "; charset=..." are ignored, and the suffix defaults to
    ".png" when the type is missing or unknown. The file is created with
    ``delete=False``, so the caller must remove it.

    Args:
        image_url: URL to download; falsy values short-circuit to None.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The temporary file path, or None if ``image_url`` is empty or the
        download fails (any partially written file is removed).
    """
    if not image_url:
        return None  # Return None if image_url is empty
    temp_file_path = None
    try:
        with requests.get(image_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # guess_extension() fails on a missing header (None) and never
            # matches a value that still carries parameters, so normalise first.
            raw_type = response.headers.get('content-type') or ""
            content_type = raw_type.split(";", 1)[0].strip()
            extension = mimetypes.guess_extension(content_type) if content_type else None
            with tempfile.NamedTemporaryFile(suffix=extension or ".png", delete=False) as tmp_file:
                temp_file_path = tmp_file.name  # record early so failures can clean up
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
        return temp_file_path
    except Exception as e:
        print(f"Error downloading image: {e}")
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass  # best-effort cleanup of the partial file
        return None
def create_user_with_email_and_password(email, password):
    """
    Register a new Firebase Authentication account.

    Returns:
        ``(uid, None)`` when the account is created, or ``(None, message)``
        when the Admin SDK call raises.
    """
    try:
        record = auth.create_user(email=email, password=password)
        new_uid = record.uid
    except Exception as exc:
        print(f"Error creating user {exc}")
        return None, str(exc)
    return new_uid, None
def sign_in_with_email_and_password(email, password, api_key=None, timeout=30):
    """
    Verify an email/password pair against Firebase Authentication.

    The Admin SDK cannot check passwords: ``auth.get_user_by_email`` only looks
    an account up, which would let anyone who knows an email sign in. This
    calls the Identity Toolkit REST endpoint ``accounts:signInWithPassword``
    instead, which needs the project's Web API key.

    Args:
        email: Account email.
        password: Account password.
        api_key: Firebase Web API key; defaults to the FIREBASE_WEB_API_KEY
            environment variable.
        timeout: Seconds to wait for the REST call.

    Returns:
        ``(uid, None)`` when the credentials are valid, otherwise
        ``(None, error_message)``. Fails closed when no API key is configured.
    """
    api_key = api_key or os.getenv("FIREBASE_WEB_API_KEY")
    if not api_key:
        error = "Firebase Web API key not configured (set FIREBASE_WEB_API_KEY)"
        print(f"Error logging in user: {error}")
        return None, error
    try:
        response = requests.post(
            "https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword",
            params={"key": api_key},
            json={"email": email, "password": password, "returnSecureToken": True},
            timeout=timeout,
        )
        data = response.json()
        if response.ok and data.get("localId"):
            return data["localId"], None
        # Failures look like {"error": {"message": "INVALID_PASSWORD", ...}}.
        api_error = data.get("error")
        error = api_error.get("message") if isinstance(api_error, dict) else api_error
        error = str(error or f"HTTP {response.status_code}")
    except Exception as e:
        error = str(e)
    print(f"Error logging in user: {error}")
    return None, error
def _image_to_png_bytes(image):
    """Encode a PIL image as PNG bytes for ``st.download_button``."""
    # The PNG writer does not accept every Pillow mode (e.g. CMYK from some
    # JPEGs); convert anything outside the common PNG modes to RGB first.
    if image.mode not in ("1", "L", "LA", "P", "RGB", "RGBA"):
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _render_auth():
    """Show the login/sign-up form; on success store the uid in session state."""
    st.subheader("Login/Signup")
    st.session_state.login_status = False
    login_option = st.radio("Select Option", ("Login", "Sign Up"))
    if login_option == "Login":
        button_label, handler = "Log In", sign_in_with_email_and_password
        success_msg, failure_prefix = "Logged in successfully!", "Login Failed"
    else:
        button_label, handler = "Sign Up", create_user_with_email_and_password
        success_msg, failure_prefix = "Signed up successfully!", "Sign Up Failed"
    email = st.text_input("Email")
    password = st.text_input("Password", type="password")
    if not st.button(button_label):
        return
    if not (email and password):
        st.error("Please provide both email and password")
        return
    user_id, error = handler(email, password)
    if user_id:
        st.success(success_msg)
        st.session_state.user_id = user_id
        st.session_state.login_status = True
    else:
        st.error(f"{failure_prefix}: {error}")


def _render_generator():
    """Show the generation form, the result, and the log-out button."""
    prompt = st.text_input("Prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        ["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=4,  # Default value is 16:9
    )
    realism = st.checkbox("Realism", value=False)
    if st.button("Generate Image"):
        if not (prompt and prompt.strip()):
            st.error("Please enter a prompt")
        else:
            image, url, file_url = generate_image(
                prompt, aspect_ratio, realism, st.session_state.user_id
            )
            if isinstance(image, str):
                st.error(image)  # generate_image reports failures as a message
            elif image:
                st.image(image, caption="Generated Image", use_column_width=True)
                if url:
                    st.write(f"Firestore URL: {url}")
                if file_url:
                    # Serve the image already in memory: no extra network fetch
                    # (which fails for non-public blobs), and the declared MIME
                    # type matches the bytes.
                    st.download_button(
                        label="Download Image",
                        data=_image_to_png_bytes(image),
                        file_name="image.png",
                        mime="image/png",
                    )
    if st.button("Log Out"):
        st.session_state.user_id = None
        st.rerun()


def main():
    """Render the app: an auth form when signed out, the image generator when signed in."""
    st.title("Image Generator")
    # Initialize session state
    if "user_id" not in st.session_state:
        st.session_state.user_id = None
    if not st.session_state.user_id:
        _render_auth()
    # Re-checked (not elif) so a successful login shows the generator in the same run.
    if st.session_state.user_id:
        _render_generator()


if __name__ == "__main__":
    main()