MOSPI_analysis_tool / table_analysis_for_pdf.py
akshansh36's picture
Upload 10 files
eef9e83 verified
import pandas as pd
from io import BytesIO
import requests
import streamlit as st
from pymongo import MongoClient
import os
from dotenv import load_dotenv
import json
from pygwalker.api.streamlit import StreamlitRenderer
# Load environment variables
# NOTE: everything below runs at import time, so importing this module
# reads the environment and constructs the MongoDB client as a side effect.
load_dotenv()
# os.getenv returns None for any variable that is not set; no validation
# is performed here, so a missing setting surfaces only when it is used.
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = os.getenv("DB_NAME")
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
# Module-level handles shared by the page functions in this file;
# `collection` is the one queried for PDF records.
mongo_client = MongoClient(MONGO_URI)
db = mongo_client[DB_NAME]
collection = db[COLLECTION_NAME]
# Load CSV from S3 URL
def load_csv_from_url(object_url, timeout=30):
    """Download a CSV file over HTTP and parse it into a DataFrame.

    Args:
        object_url: URL of the CSV object to fetch.
        timeout: Seconds to wait for the server before giving up. Requests
            applies no timeout unless one is passed, so without it a stalled
            server would block the calling script indefinitely.

    Returns:
        A pandas DataFrame parsed from the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(object_url, timeout=timeout)
    response.raise_for_status()
    return pd.read_csv(BytesIO(response.content))
# Analyze column data
def analyze_column_data(df):
    """Compute simple per-column statistics for a DataFrame.

    Numeric columns get "Mean", "Median", "Mode", "Unique Values" and
    "Null Values". All other columns get "Unique Values", "Null Values" and
    "Top Categories" (the five most frequent values with their counts).

    Args:
        df: The pandas DataFrame to analyze.

    Returns:
        A dict mapping each column label to its dict of statistics. "Mode" is
        None when the column has no non-null values. "Mode" and "Null Values"
        are returned as builtin Python scalars.
    """

    def _to_builtin(value):
        # numpy scalars (e.g. numpy.int64) are not instances of the builtin
        # int and are rejected by the stdlib json encoder, so unwrap them.
        # Non-scalar values are passed through unchanged.
        if pd.api.types.is_scalar(value) and hasattr(value, "item"):
            return value.item()
        return value

    analysis = {}
    for col in df.columns:
        series = df[col]
        unique_count = series.nunique()
        null_count = _to_builtin(series.isnull().sum())
        if pd.api.types.is_numeric_dtype(series):
            # Compute the mode once; it is empty when every value is null.
            modes = series.mode()
            mode_value = None if modes.empty else _to_builtin(modes.iloc[0])
            analysis[col] = {
                "Mean": series.mean(),
                "Median": series.median(),
                "Mode": mode_value,
                "Unique Values": unique_count,
                "Null Values": null_count,
            }
        else:
            analysis[col] = {
                "Unique Values": unique_count,
                "Null Values": null_count,
                "Top Categories": series.value_counts().head(5).to_dict(),
            }
    return analysis
# Display analysis for a selected table
def display_table_analysis(table):
    """Render the preview, download, summaries and chart explorer for a table.

    Args:
        table: Mapping describing the table. ``csv_object_url`` is required;
            ``description`` and ``column_summary`` (column name mapped to a
            description) are optional.

    Raises:
        KeyError: If ``table`` has no ``csv_object_url`` entry.
        requests.HTTPError: If the CSV URL answers with a 4xx/5xx status.
    """
    # Fetch the CSV once and reuse the bytes for both parsing and the
    # download button, rather than requesting the same URL twice.
    response = requests.get(table['csv_object_url'], timeout=30)
    response.raise_for_status()
    csv_bytes = response.content
    df = pd.read_csv(BytesIO(csv_bytes))

    # Drop the last row when any of its cells mentions "total".
    # The emptiness guard matters: iloc[-1] raises IndexError on a
    # header-only CSV.
    if not df.empty and df.iloc[-1].astype(str).str.contains("total", case=False).any():
        df = df.iloc[:-1]

    # Table preview
    st.subheader("CSV Preview")
    st.dataframe(df, height=300)

    # Download Button (serves the file exactly as fetched, "total" row included)
    st.download_button(
        label="Download CSV",
        data=csv_bytes,
        file_name="table_data.csv",
        mime="text/csv"
    )

    # Table Description
    if 'description' in table:
        st.subheader("Table Description")
        st.write(table['description'])

    # Column Summary
    st.subheader("Column Summary")
    # `or {}` also covers a key that is present but stored as None.
    column_summary = table.get('column_summary') or {}
    column_analysis = analyze_column_data(df)
    numeric_keys = ("Mean", "Median", "Mode", "Unique Values", "Null Values")
    categorical_keys = ("Unique Values", "Null Values", "Top Categories")
    col1, col2 = st.columns(2)
    for idx, (col_name, col_description) in enumerate(column_summary.items()):
        # Alternate entries between the two layout columns.
        with col1 if idx % 2 == 0 else col2:
            st.markdown(f"Column Name: **{col_name}**")
            st.write(f"Description: {col_description}")
            # The summary may describe a column that is absent from the CSV;
            # show only its description instead of raising KeyError.
            if col_name not in df.columns:
                continue
            analysis = column_analysis.get(col_name, {})
            if pd.api.types.is_numeric_dtype(df[col_name]):
                stat_keys = numeric_keys
            else:
                stat_keys = categorical_keys
            st.write({key: analysis.get(key) for key in stat_keys})

    # Graphical Analysis using Pygwalker
    st.subheader("Graphical Analysis of Table")
    pyg_app = StreamlitRenderer(df)
    pyg_app.explorer()
# Main function to render the View Table Analysis page for PDF tables
def view_pdf_table_analysis_page(url):
    """Render the "PDF Table Analysis" page for the PDF stored under ``url``.

    Looks up the document whose ``object_url`` equals ``url``, lets the user
    pick one of its ``table_data`` entries, and shows that table's data, an
    Excel download, its description, per-column summary and a chart explorer.

    Args:
        url: Value matched against the ``object_url`` field in the collection.

    Raises:
        KeyError: If the selected table entry has no ``csv_object_url``.
        requests.HTTPError: If the table's CSV URL answers with 4xx/5xx.
    """
    if st.button("Back", key="back_button"):
        st.session_state.page = "view_pdf"
        st.rerun()

    # Retrieve table data for the PDF
    pdf_data = collection.find_one({"object_url": url})
    if pdf_data is None:
        # find_one returns None when nothing matches; report it instead of
        # failing with AttributeError on the .get() below.
        st.error("No record found for this PDF.")
        return
    # `or []` also covers a field that is present but stored as None.
    tables = pdf_data.get("table_data") or []

    # Display the total number of tables
    st.title("PDF Table Analysis")
    st.write(f"Total tables found: {len(tables)}")

    # Reset a missing or stale selection (e.g. left over from a PDF that had
    # more tables) so it is a valid default index for the radio below.
    if (
        "selected_table" not in st.session_state
        or st.session_state.selected_table is None
        or st.session_state.selected_table >= len(tables)
    ):
        st.session_state.selected_table = 0

    selected_table_idx = st.radio(
        "Select a table to analyze",
        options=range(len(tables)),
        format_func=lambda x: f"Analyze Table {x + 1}",
        index=st.session_state.selected_table  # Safely use the default if uninitialized
    )
    st.session_state.selected_table = selected_table_idx

    if st.session_state.selected_table is not None:
        selected_table_data = tables[st.session_state.selected_table]
        st.subheader(f"Analysis for Table {st.session_state.selected_table + 1}")

        csv_url = selected_table_data['csv_object_url']
        df = load_csv_from_url(csv_url)
        # Drop the last row when any of its cells mentions "total".
        # The emptiness guard matters: iloc[-1] raises IndexError on an
        # empty frame.
        if not df.empty and df.iloc[-1].apply(lambda x: "total" in str(x).lower()).any():
            df = df.iloc[:-1]
        st.dataframe(df)  # Interactive, scrollable table

        excel_buffer = BytesIO()
        with pd.ExcelWriter(excel_buffer, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name="Sheet1")
        excel_buffer.seek(0)  # Reset buffer position

        # Download Button
        st.download_button(
            label="Download Full Excel Sheet",
            data=excel_buffer,
            file_name="table_data.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        )
        st.markdown("<hr>", unsafe_allow_html=True)

        table_description = selected_table_data.get("description", None)
        if table_description:
            # Table Description
            st.subheader("Table Description")
            st.write(table_description)

        # Column Summary
        st.markdown("<hr>", unsafe_allow_html=True)
        st.subheader("Column Summary")
        with st.container(height=400, border=False):
            column_summary = selected_table_data.get("column_summary", None)
            if column_summary:
                # Column-level descriptions and analysis
                column_analysis = analyze_column_data(df)
                numeric_keys = ("Mean", "Median", "Mode", "Unique Values", "Null Values")
                categorical_keys = ("Unique Values", "Null Values", "Top Categories")
                col1, col2 = st.columns(2)
                for idx, (col_name, col_description) in enumerate(column_summary.items()):
                    # Determine which column to use based on the index
                    with col1 if idx % 2 == 0 else col2:
                        st.markdown(f"Column Name : **{col_name}**")
                        st.write(f"Column Description : {col_description}")
                        # The summary may describe a column that is absent
                        # from the CSV; show only its description instead of
                        # raising KeyError.
                        if col_name not in df.columns:
                            continue
                        # Display basic analysis
                        analysis = column_analysis.get(col_name, {})
                        if pd.api.types.is_numeric_dtype(df[col_name]):
                            stat_keys = numeric_keys  # Numeric column analysis
                        else:
                            stat_keys = categorical_keys  # Categorical column analysis
                        st.write({key: analysis.get(key) for key in stat_keys})

        st.markdown("<hr>", unsafe_allow_html=True)
        st.subheader("Graphical Analysis of Table")
        best_col1 = selected_table_data.get("best_col1")
        best_col2 = selected_table_data.get("best_col2")
        if best_col1 in df.columns and best_col2 in df.columns:
            default_chart_config = {
                "mark": "bar",
                "encoding": {
                    "x": {"field": best_col1, "type": "nominal"},
                    "y": {"field": best_col2, "type": "quantitative"}
                }
            }
            # Convert default_chart_config to JSON string for Pygwalker spec parameter
            # NOTE(review): this dict is Vega-Lite shaped; confirm that
            # pygwalker's `spec` accepts it rather than only its own exported
            # chart configuration.
            pyg_app = StreamlitRenderer(df, spec=json.dumps(default_chart_config))
        else:
            # Without two usable columns a default chart would reference a
            # null or unknown field, so open the explorer with no preset.
            pyg_app = StreamlitRenderer(df)
        pyg_app.explorer()