# ImagiGen_v2 / app.py
# Last commit: 06b3f66 "Update app.py" by Tejasva-Maurya (11.7 kB)
import io
import logging

import streamlit as st
from PIL import Image

import google_auth
import mongo_db

# Functions for image and prompt generation
from image_models import (
    flux,
    stable_diffusion,
)
from prompt_models import Qwen_72b, microsoft_phi, Mixtral
# Browser-tab title plus page layout: wide main area, sidebar open on load.
# Kept ahead of every other st.* call; NOTE(review) older Streamlit releases
# reject set_page_config once another command has run — confirm for the
# pinned version.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Welcome to ImagiGen: AI-Powered Image Synthesis",
)
# Custom CSS for fonts, background, and animations
st.markdown(
"""
<style>
/* Add a custom font */
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
body {
font-family: 'Poppins', sans-serif;
background: linear-gradient(135deg, #f8f9fa, #e9ecef);
animation: gradient-animation 10s ease infinite;
}
/* Header and title styles */
h1, h2, h3 {
color: #343a40;
text-align: center;
}
/* Button styles */
.stButton button {
background-color: #007bff;
color: white;
border-radius: 8px;
transition: all 0.3s ease;
}
.stButton button:hover {
background-color: #0056b3;
transform: scale(1.05);
}
/* Smooth gradient animation */
@keyframes gradient-animation {
0% { background: linear-gradient(135deg, #f8f9fa, #e9ecef); }
50% { background: linear-gradient(135deg, #e9ecef, #f8f9fa); }
100% { background: linear-gradient(135deg, #f8f9fa, #e9ecef); }
}
/* Custom input box styling */
input, textarea {
border: 2px solid #ced4da;
border-radius: 5px;
padding: 10px;
}
/* Sidebar customization */
.sidebar .sidebar-content {
background-color: #ffffff;
border-radius: 10px;
padding: 20px;
}
</style>
""",
unsafe_allow_html=True,
)
# List of functions
image_functions = [flux, stable_diffusion]
prompt_functions = [Qwen_72b, microsoft_phi, Mixtral]
# Initialize session state for login status
if "logged_in" not in st.session_state:
st.session_state.logged_in = False
def register():
    """Render the "Register New Account" expander in the centre column.

    Two registration paths are offered:

    * Email/password: the form is checked locally (passwords must match and
      neither the email nor the password may be empty) before calling
      ``mongo_db.register``.
    * Google: the account identity comes from ``google_auth.auth()`` and is
      passed to ``mongo_db.google_register``.

    In both cases the status string returned by ``mongo_db`` selects the
    message shown to the user.
    """
    # Status strings compared against mongo_db's return value. The duplicate
    # check is case-insensitive because the two paths previously expected
    # different capitalisations ("Email ID ..." vs "email ID ...").
    # NOTE(review): confirm the exact strings mongo_db returns.
    already_registered = "email id already registered"
    success = "Registration successful"

    _, centre, _ = st.columns([1, 2, 1])  # Left, Center (main content), Right
    with centre:
        with st.expander("Register New Account", expanded=False):
            st.subheader("Create a New Account", divider="rainbow")
            new_email_id = st.text_input("email_id", placeholder="Enter an Email ID")
            new_password = st.text_input(
                "New Password", type="password", placeholder="Choose a password"
            )
            confirm_password = st.text_input(
                "Confirm Password",
                type="password",
                placeholder="Re-enter your password",
            )

            if st.button("Register"):
                if new_password != confirm_password:
                    st.error("Passwords do not match. Please try again.")
                elif not new_email_id or not new_password:
                    st.error("All fields are required.")
                else:
                    response = mongo_db.register(new_email_id, new_password)
                    logging.getLogger(__name__).info(
                        "mongo_db.register returned %r", response
                    )
                    if str(response).casefold() == already_registered:
                        st.error(
                            "Email ID already Registered. Please choose a different email id."
                        )
                    elif response == success:
                        st.success("Registration successful! You can now log in.")
                        st.balloons()
                    else:
                        st.error("Unexpected error occur. Please Retry....")

            if st.button("Register Using Google Account"):
                user_info = google_auth.auth()
                # Formatting missing fields with an f-string used to produce
                # the literal "None", registering an account under that email.
                # NOTE(review): assumes auth() returns a dict shaped like
                # {"user": {"displayName": ..., "emailAddress": ...}}, or a
                # falsy value on failure — confirm against google_auth.
                user = (user_info or {}).get("user") or {}
                email = user.get("emailAddress")
                if not email:
                    st.error(
                        "Could not read the email of your Google account. Please Retry...."
                    )
                else:
                    user_name = str(user.get("displayName") or "")
                    response = mongo_db.google_register(user_name, str(email))
                    if str(response).casefold() == already_registered:
                        st.error("Email ID already Registered. Please Log in.")
                    elif response == success:
                        st.success("Registration successful! You can now log in.")
                        st.balloons()
                    else:
                        st.error("Unexpected error occur. Please Retry....")
# Function to handle login
def login():
col1, col2, col3 = st.columns([1, 2, 1]) # Left, Center (main content), Right
with col2:
_, centre, _ = st.columns([1, 2, 1])
with centre:
st.image("ImagiGen--Logo.svg", use_container_width=True)
st.header("Login Page", divider="rainbow")
st.subheader("Please Login to Continue")
email_id = st.text_input("Email Id", placeholder="Enter your Email Id")
password = st.text_input(
"Password", type="password", placeholder="Enter your password"
)
if st.button("Login"):
response = mongo_db.login(email_id, password)
if response == "Login successful":
st.session_state.logged_in = True
st.success("Login successful! Redirecting...")
st.rerun()
else:
st.error("Invalid email id or password")
if st.button("google"):
user_info = google_auth.auth()
email = f'{user_info.get("user").get("emailAddress")}'
response = mongo_db.google_login(email)
if response == "Login successful":
st.session_state.logged_in = True
st.rerun()
else:
st.error("Email id is not registered. Please Register")
def _render_generated_image(image_fn, prompt, image_number, prompt_number, model_number):
    """Generate one image with ``image_fn`` and display it in the active container.

    Any failure in the model call or in decoding the returned bytes is logged
    and replaced by the ``error.jpg`` placeholder, so one bad request does not
    stop the remaining images from being rendered.

    NOTE(review): assumes ``image_fn(prompt)`` returns encoded image bytes.
    """
    with st.spinner(f"Generating Image {image_number}..."):
        try:
            image_bytes = image_fn(prompt)
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open only reads the header; force the full decode here so
            # corrupt data is caught now rather than inside st.image.
            image.load()
        except Exception:
            logging.getLogger(__name__).exception(
                "Image generation failed (prompt %d, model %d)",
                prompt_number,
                model_number,
            )
            st.image(
                "error.jpg",
                caption="Server error occur",
            )
        else:
            st.image(
                image,
                caption=f"Prompt {prompt_number}, Model {model_number}",
                use_container_width=True,
            )


# Function for the platform/dashboard page
def platform_page():
    """Render the dashboard shown to a logged-in user.

    Flow:
      1. The user types a short description.
      2. "Generate Prompt" expands it into three prompts, one per model in
         ``prompt_functions``.
      3. The three prompts are shown in editable text areas.
      4. "Generate Images" renders every prompt with every model in
         ``image_functions``, in a grid with one column per prompt and one
         row per image model.

    Prompts live in ``st.session_state`` under ``model_1``..``model_3`` so
    they survive reruns. "Clear/Regenerate" empties them again.
    """
    prompt_keys = ("model_1", "model_2", "model_3")

    _, centre, _ = st.columns([3, 4, 3])
    with centre:
        st.image("ImagiGen--Logo.svg", use_container_width=True)
    st.header("Welcome to ImagiGen: AI-Powered Image Synthesis", divider="rainbow")
    if st.button("Logout"):
        st.session_state.logged_in = False
        st.rerun()
    st.markdown("### Describe Your Creative Vision")

    # Initialize session state for prompts (first run only).
    for key in prompt_keys:
        if key not in st.session_state:
            st.session_state[key] = ""

    # User input for the prompt
    raw_user_input = st.text_input(
        "Enter your description for prompt generation:",
        "",
        placeholder="E.g., A vivid image of a flowing red silk dress in a windy desert...",
    )
    user_input = f"""Instruction: Generate a highly detailed and vivid prompt based on the user input, specifically for the art and design industry, with a focus on clothing... User Input: {raw_user_input}"""

    if st.button("Generate Prompt"):
        # Avoid three paid LLM calls on an empty description.
        if not raw_user_input.strip():
            st.warning("Please enter a description before generating prompts.")
        else:
            with st.spinner("Generating Prompts..."):
                for key, prompt_fn in zip(prompt_keys, prompt_functions):
                    st.session_state[key] = prompt_fn(user_input)

    # Display text areas to edit each generated prompt; edits are written
    # back so the image step uses the user's version.
    cols = st.columns(len(prompt_keys))
    for number, (col, key) in enumerate(zip(cols, prompt_keys), start=1):
        with col:
            st.session_state[key] = st.text_area(
                f"Edit Prompt {number}:", st.session_state[key], height=200
            )

    if st.button("Generate Images"):
        prompts = [st.session_state[key] for key in prompt_keys]
        cols = st.columns(len(prompts))
        image_number = 0
        # Outer loop over models, inner over prompts: prompt N always lands in
        # column N, so each model forms one row of the grid.
        for model_number, image_fn in enumerate(image_functions, start=1):
            for prompt_number, (col, prompt) in enumerate(zip(cols, prompts), start=1):
                image_number += 1
                with col:
                    _render_generated_image(
                        image_fn, prompt, image_number, prompt_number, model_number
                    )

    if st.button("Clear/Regenerate"):
        for key in prompt_keys:
            st.session_state[key] = ""
        st.rerun()
# Page routing on every run: anonymous visitors get the login form followed by
# the registration expander; once login() sets the flag, show the dashboard.
if not st.session_state.logged_in:
    login()
    register()
else:
    platform_page()