# Source snapshot: Hugging Face Space (status: Sleeping), file size 5,287 bytes, commit 5f8b3ec.
import streamlit as st
import os
import psycopg2 as pgsql
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
import google.generativeai as genai
# Environment / API setup: load variables from a local .env file (if present)
# into the process environment, then give the Gemini client the API key.
load_dotenv()
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Function to load Google Gemini Model and provide queries as response
def get_gemini_response(question, prompt):
model = genai.GenerativeModel('gemini-pro')
response = model.generate_content([prompt[0], question])
return response.text.strip()
# Function to run a query against the database and return it as a DataFrame
def read_sql_query(sql, db_params, *, readonly=True):
    """Execute *sql* on PostgreSQL and return the result set as a DataFrame.

    The SQL comes from an LLM driven by free-form user input, so by default the
    session is opened read-only. Any statement that tries to modify data is
    then rejected by the server instead of being executed and committed.

    Args:
        sql: The SQL statement to run.
        db_params: Keyword arguments forwarded to ``psycopg2.connect``.
        readonly: When True (default), run in a read-only session and never
            commit. Pass False to run in a normal session and commit afterwards.

    Returns:
        A DataFrame holding the result rows and column names, with a ``price``
        column coerced to numeric (invalid values become NaN) if present.
        Returns an empty DataFrame if the statement produces no result set, or
        on any error. Errors are reported in the UI via ``st.error``.
    """
    conn = None
    try:
        conn = pgsql.connect(**db_params)
        if readonly:
            # Must be set before the first statement opens a transaction.
            conn.set_session(readonly=True)
        with conn.cursor() as cur:  # closes the cursor even if execute() raises
            cur.execute(sql)
            if cur.description is None:
                # No result set (e.g. a utility command); fetchall() would raise here.
                rows, colnames = [], []
            else:
                rows = cur.fetchall()
                colnames = [desc[0] for desc in cur.description]
        if not readonly:
            conn.commit()
        df = pd.DataFrame(rows, columns=colnames)
        # Convert 'price' column to numeric if it exists
        if 'price' in df.columns:
            df['price'] = pd.to_numeric(df['price'], errors='coerce')
        return df
    except Exception as e:
        st.error(f"An error occurred: {e}")
        return pd.DataFrame()
    finally:
        # Always release the connection; any uncommitted transaction is discarded.
        if conn is not None:
            conn.close()
# PostgreSQL connection parameters, read from the standard libpq environment
# variables (these can be supplied via a .env file if load_dotenv() runs first).
# The fallbacks are the previous local-development values. Set PGPASSWORD
# rather than relying on the default, and keep real credentials out of source.
db_params = {
    'dbname': os.getenv('PGDATABASE', 'GeminiPro'),
    'user': os.getenv('PGUSER', 'postgres'),
    'password': os.getenv('PGPASSWORD', 'root'),
    'host': os.getenv('PGHOST', 'localhost'),
    'port': int(os.getenv('PGPORT', '5432')),
}
# Instruction prompt sent to Gemini ahead of each question (only element [0] is used).
# It pins the schema, gives few-shot examples and restricts output to read-only SQL.
prompt = [
"""
You are an expert in converting English questions to SQL queries!
The SQL database has a table named 'department_store' with the following columns:
id, product_name, category, price, stock_quantity, supplier, last_restock_date.
Examples:
- How many products do we have in total?
The SQL command will be: SELECT COUNT(*) FROM department_store;
- What are all the products in the Electronics category?
The SQL command will be: SELECT * FROM department_store WHERE category = 'Electronics';
Only generate a single read-only SELECT statement. Never generate INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE or any other statement that modifies data or schema.
Return only the SQL statement, with no explanation.
The SQL code should not include backticks and should not start with the word 'SQL'.
"""
]
# Streamlit App
# set_page_config must be the first Streamlit command the script executes.
st.set_page_config(page_title="AutomatiX - Department Store Analytics", layout="wide")


def _render_visualizations(df):
    """Draw the charts that the columns of *df* support, in a two-column grid.

    Each chart is drawn only when the columns it needs are present, so any
    query result shape is handled. *df* itself is not modified.
    """
    st.subheader("Data Visualizations")
    col1, col2 = st.columns(2)
    has_price = 'price' in df.columns and df['price'].notna().any()
    with col1:
        if has_price:
            fig = px.histogram(df, x='price', title='Price Distribution')
            st.plotly_chart(fig, use_container_width=True)
        if 'category' in df.columns:
            category_counts = df['category'].value_counts()
            fig = px.pie(values=category_counts.values, names=category_counts.index, title='Products by Category')
            st.plotly_chart(fig, use_container_width=True)
    with col2:
        if 'last_restock_date' in df.columns:
            # Convert into a separate Series instead of adding helper columns to df.
            restock_dates = pd.to_datetime(df['last_restock_date'], errors='coerce')
            restock_counts = restock_dates.dt.to_period('M').value_counts().sort_index()
            fig = px.line(x=restock_counts.index.astype(str), y=restock_counts.values, title='Restocking Trend')
            st.plotly_chart(fig, use_container_width=True)
        if 'product_name' in df.columns and has_price:
            top_prices = df.sort_values('price', ascending=False).head(10)
            fig = px.bar(top_prices, x='product_name', y='price', title='Top 10 Most Expensive Products')
            st.plotly_chart(fig, use_container_width=True)


def _answer_question(question):
    """Translate *question* to SQL with Gemini, run it, and render the results."""
    with st.spinner("Generating and fetching data..."):
        try:
            sql_query = get_gemini_response(question, prompt)
        except Exception as e:
            # Report any failure talking to Gemini instead of crashing the page.
            st.error(f"Could not generate a SQL query: {e}")
            return
        df = read_sql_query(sql_query, db_params)
    if df.empty:
        st.warning("No data returned from the query.")
        return
    st.success("Query executed successfully!")
    st.subheader("Data Table")
    st.dataframe(df)
    _render_visualizations(df)


# Sidebar for user input
st.sidebar.title("AutomatiX - Department Store Chat Interface")
question = st.sidebar.text_area("Enter your question:", key="input")
submit = st.sidebar.button("Ask Me")

# Main content area
st.title("AutomatiX - Department Store Dashboard")
if not submit:
    st.info("Enter a question and click 'Ask Me' to get started!")
elif not question.strip():
    # Don't spend a model call on an empty question.
    st.warning("Please enter a question before clicking 'Ask Me'.")
else:
    _answer_question(question)

# Footer
st.sidebar.markdown("---")
st.sidebar.info("You can ask questions like:\n"
                "1.What are all the products in the Electronics category?\n"
                "2.What is the average price of products in each category?\n"
                "3.Which products have a stock quantity less than 30?\n"
                "4.What are the top 5 most expensive products?")
st.sidebar.warning("CopyRights@AutomatiX - Powered by Streamlit and Google Gemini")