Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import io | |
import google_auth | |
import mongo_db | |
# Functions for image and prompt generation | |
from image_models import ( | |
flux, | |
stable_diffusion, | |
) | |
from prompt_models import Qwen_72b, microsoft_phi, Mixtral | |
# Set page config with a custom title and layout. This must be the first
# Streamlit command the script executes.
st.set_page_config(
    page_title="Welcome to ImagiGen: AI-Powered Image Synthesis",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for fonts, background, and animations.
st.markdown(
    """
<style>
/* Add a custom font */
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
body, .stApp {
    font-family: 'Poppins', sans-serif;
}
/* Smooth gradient background. CSS cannot interpolate between two gradients,
   so the gradient is drawn oversized and its position is animated instead.
   It is applied to .stApp because Streamlit's app container paints its own
   background over <body>. */
.stApp {
    background: linear-gradient(135deg, #f8f9fa, #e9ecef, #f8f9fa);
    background-size: 400% 400%;
    animation: gradient-animation 10s ease infinite;
}
@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
/* Header and title styles */
h1, h2, h3 {
    color: #343a40;
    text-align: center;
}
/* Button styles */
.stButton button {
    background-color: #007bff;
    color: white;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.stButton button:hover {
    background-color: #0056b3;
    transform: scale(1.05);
}
/* Custom input box styling */
input, textarea {
    border: 2px solid #ced4da;
    border-radius: 5px;
    padding: 10px;
}
/* Sidebar customization (the legacy ".sidebar .sidebar-content" classes
   are not present in current Streamlit markup) */
[data-testid="stSidebar"] {
    background-color: #ffffff;
    border-radius: 10px;
    padding: 20px;
}
</style>
""",
    unsafe_allow_html=True,
)
# Text models that turn a user's description into an image prompt.
prompt_functions = [Qwen_72b, microsoft_phi, Mixtral]
# Image generators; position in this list is the "Model N" number shown in
# image captions (index 0 -> Model 1, index 1 -> Model 2).
image_functions = [flux, stable_diffusion]

# Keep the login flag in session state so it survives Streamlit reruns;
# a brand-new session starts logged out.
if "logged_in" not in st.session_state:
    st.session_state["logged_in"] = False
def register():
    """Render the "Register New Account" expander in the centre column.

    Two sign-up paths are offered:

    * email/password, stored via ``mongo_db.register``;
    * Google, via ``google_auth.auth()`` followed by ``mongo_db.google_register``.

    The outcome is reported with ``st.error`` / ``st.success`` according to the
    status string returned by the ``mongo_db`` call.
    """

    def _response_is(response, expected):
        # Status strings are compared case-insensitively: the email and Google
        # paths checked "Email ID already Registered" and "email ID already
        # Registered" for the same condition, so an exact match silently
        # missed one of the spellings.
        return isinstance(response, str) and response.casefold() == expected.casefold()

    _, centre, _ = st.columns([1, 2, 1])  # Left, Center (main content), Right
    with centre:
        with st.expander("Register New Account", expanded=False):
            st.subheader("Create a New Account", divider="rainbow")
            new_email_id = st.text_input("email_id", placeholder="Enter an Email ID")
            new_password = st.text_input(
                "New Password", type="password", placeholder="Choose a password"
            )
            confirm_password = st.text_input(
                "Confirm Password",
                type="password",
                placeholder="Re-enter your password",
            )

            if st.button("Register"):
                # Report missing fields first so an empty email is not masked
                # by a password-mismatch message.
                if not new_email_id or not new_password:
                    st.error("All fields are required.")
                elif new_password != confirm_password:
                    st.error("Passwords do not match. Please try again.")
                else:
                    response = mongo_db.register(new_email_id, new_password)
                    if _response_is(response, "Email ID already Registered"):
                        st.error(
                            "Email ID already Registered. Please choose a different email id."
                        )
                    elif _response_is(response, "Registration successful"):
                        st.success("Registration successful! You can now log in.")
                        st.balloons()
                    else:
                        st.error("Unexpected error occur. Please Retry....")

            if st.button("Register Using Google Account"):
                # Assumes auth() returns a dict shaped like
                # {"user": {"displayName": ..., "emailAddress": ...}} -- inferred
                # from the keys read here; confirm against google_auth.
                user_info = google_auth.auth() or {}
                google_user = user_info.get("user") or {}
                email = google_user.get("emailAddress")
                if not email:
                    # f-string formatting of a missing field used to produce the
                    # literal string "None", which was then registered as an email.
                    st.error(
                        "Google sign-in did not return an email address. Please retry."
                    )
                else:
                    user_name = str(google_user.get("displayName") or "")
                    response = mongo_db.google_register(user_name, str(email))
                    if _response_is(response, "Email ID already Registered"):
                        st.error("Email ID already Registered. Please Log in.")
                    elif _response_is(response, "Registration successful"):
                        st.success("Registration successful! You can now log in.")
                        st.balloons()
                    else:
                        st.error("Unexpected error occur. Please Retry....")
# Function to handle login
def login():
    """Render the login page in the centre column.

    Supports email/password login via ``mongo_db.login`` and Google login via
    ``google_auth.auth()`` + ``mongo_db.google_login``. When the backend answers
    ``"Login successful"``, ``st.session_state.logged_in`` is set and the script
    is rerun so the dashboard is displayed.
    """
    _, main_col, _ = st.columns([1, 2, 1])  # Left, Center (main content), Right
    with main_col:
        # Second (and last allowed) level of column nesting, used only to
        # centre the logo inside the main column.
        _, logo_col, _ = st.columns([1, 2, 1])
        with logo_col:
            st.image("ImagiGen--Logo.svg", use_container_width=True)
        st.header("Login Page", divider="rainbow")
        st.subheader("Please Login to Continue")
        email_id = st.text_input("Email Id", placeholder="Enter your Email Id")
        password = st.text_input(
            "Password", type="password", placeholder="Enter your password"
        )

        if st.button("Login"):
            if not email_id or not password:
                st.error("Please enter both your email id and password.")
            elif mongo_db.login(email_id, password) == "Login successful":
                st.session_state.logged_in = True
                st.success("Login successful! Redirecting...")
                st.rerun()
            else:
                st.error("Invalid email id or password")

        if st.button("google"):
            # Assumes auth() returns {"user": {"emailAddress": ...}} -- inferred
            # from the keys read here; confirm against google_auth.
            user_info = google_auth.auth() or {}
            email = (user_info.get("user") or {}).get("emailAddress")
            if not email:
                # Previously a missing address was formatted as the string "None"
                # and looked up as if it were a real account.
                st.error(
                    "Google sign-in did not return an email address. Please retry."
                )
            elif mongo_db.google_login(str(email)) == "Login successful":
                st.session_state.logged_in = True
                st.rerun()
            else:
                st.error("Email id is not registered. Please Register")
# Session-state keys holding the three editable prompts, in column order.
_PROMPT_KEYS = ("model_1", "model_2", "model_3")


def build_image_jobs(prompts, model_count):
    """Plan the image grid produced by the "Generate Images" button.

    Args:
        prompts: Prompt strings, one per column.
        model_count: Number of image models; each model renders every prompt.

    Returns:
        A list of ``(column_index, image_number, model_number, prompt_number,
        prompt)`` tuples in render order: every prompt for model 1, then every
        prompt for model 2, and so on. ``column_index`` is 0-based and equals
        the prompt's position, so each image lands under its own prompt;
        ``image_number``, ``model_number`` and ``prompt_number`` are 1-based
        labels for the spinner text and caption.
    """
    jobs = []
    for model_index in range(model_count):
        for prompt_index, prompt in enumerate(prompts):
            jobs.append(
                (
                    prompt_index,
                    model_index * len(prompts) + prompt_index + 1,
                    model_index + 1,
                    prompt_index + 1,
                    prompt,
                )
            )
    return jobs


def _render_generated_image(image_fn, prompt, image_number, caption):
    """Generate one image with ``image_fn`` and display it, or the error image.

    ``image_fn(prompt)`` is expected to return encoded image bytes readable by
    PIL (inferred from how the result is used; confirm against image_models).
    """
    import logging

    with st.spinner(f"Generating Image {image_number}..."):
        try:
            image_bytes = image_fn(prompt)
            image = Image.open(io.BytesIO(image_bytes))
            # st.image stays inside the try: PIL decodes lazily, so corrupt
            # bytes may only fail once the image is actually rendered.
            st.image(image, caption=caption, use_container_width=True)
        except Exception:
            # Exception, not a bare except: Streamlit's rerun/stop signals
            # derive from BaseException and must be allowed to propagate.
            logging.getLogger(__name__).exception(
                "Image %d generation failed", image_number
            )
            st.image(
                "error.jpg",
                caption="Server error occur",
            )


# Function for the platform/dashboard page
def platform_page():
    """Render the dashboard shown to logged-in users.

    The user describes an idea; three text models (Qwen_72b, microsoft_phi,
    Mixtral) each turn it into a prompt; the prompts can be edited in three
    text areas; finally every model in ``image_functions`` renders every
    prompt into a three-column grid.
    """
    _, centre, _ = st.columns([3, 4, 3])
    with centre:
        st.image("ImagiGen--Logo.svg", use_container_width=True)
    st.header("Welcome to ImagiGen: AI-Powered Image Synthesis", divider="rainbow")
    if st.button("Logout"):
        st.session_state.logged_in = False
        st.rerun()
    st.markdown("### Describe Your Creative Vision")

    # Prompts live in session state so edits survive reruns triggered by
    # other button clicks.
    for key in _PROMPT_KEYS:
        if key not in st.session_state:
            st.session_state[key] = ""

    raw_user_input = st.text_input(
        "Enter your description for prompt generation:",
        "",
        placeholder="E.g., A vivid image of a flowing red silk dress in a windy desert...",
    )
    user_input = f"""Instruction: Generate a highly detailed and vivid prompt based on the user input, specifically for the art and design industry, with a focus on clothing... User Input: {raw_user_input}"""

    if st.button("Generate Prompt"):
        if not raw_user_input.strip():
            # Avoid three model calls on an empty description.
            st.warning("Please enter a description before generating prompts.")
        else:
            with st.spinner("Generating Prompts..."):
                st.session_state.model_1 = Qwen_72b(user_input)
                st.session_state.model_2 = microsoft_phi(user_input)
                st.session_state.model_3 = Mixtral(user_input)

    # Display text areas to edit each generated prompt.
    prompt_cols = st.columns(len(_PROMPT_KEYS))
    for number, (col, key) in enumerate(zip(prompt_cols, _PROMPT_KEYS), start=1):
        with col:
            st.session_state[key] = st.text_area(
                f"Edit Prompt {number}:", st.session_state[key], height=200
            )

    if st.button("Generate Images"):
        prompts = [st.session_state[key] for key in _PROMPT_KEYS]
        image_cols = st.columns(len(prompts))
        jobs = build_image_jobs(prompts, len(image_functions))
        for column_index, image_number, model_number, prompt_number, prompt in jobs:
            with image_cols[column_index]:
                _render_generated_image(
                    image_functions[model_number - 1],
                    prompt,
                    image_number,
                    f"Prompt {prompt_number}, Model {model_number}",
                )

    if st.button("Clear/Regenerate"):
        for key in _PROMPT_KEYS:
            st.session_state[key] = ""
        st.rerun()
# Page routing: anonymous visitors get the login form followed by the
# registration expander; authenticated users go straight to the dashboard.
if not st.session_state.logged_in:
    login()
    register()
else:
    platform_page()