# app.py – BizIntel AI Ultra v2.1
# =============================================================
# • Upload CSV / Excel • SQL–DB fetch • Trend + ARIMA forecast
# • Model explainability (summary, coef interp, ACF, back-test)
# • Gemini 1.5 Pro strategy generation
# • Optional EDA visuals • Safe Plotly PNG write to /tmp
# =============================================================
import os
import tempfile
import warnings
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tools.sm_exceptions import ConvergenceWarning
import google.generativeai as genai
# ──────────────────────────────────────────────────────────────
# Local helper modules
# ──────────────────────────────────────────────────────────────
from tools.csv_parser import parse_csv_tool
from tools.plot_generator import plot_metric_tool
from tools.forecaster import forecast_metric_tool # only for png path if needed
from tools.visuals import (
    histogram_tool, scatter_matrix_tool, corr_heatmap_tool
)
from db_connector import fetch_data_from_db, list_tables, SUPPORTED_ENGINES
# ──────────────────────────────────────────────────────────────
# Plotly safe write — ensure PNGs go to writable /tmp
# ──────────────────────────────────────────────────────────────
TMP = tempfile.gettempdir()
orig_write = go.Figure.write_image
go.Figure.write_image = lambda self, p, *a, **k: orig_write(
    self, os.path.join(TMP, os.path.basename(p)), *a, **k
)
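# Note: static export via write_image still needs a renderer such as kaleido
# installed; the wrapper above only redirects the output path to the writable
# temp directory.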
# ──────────────────────────────────────────────────────────────
# Gemini 1.5 Pro setup
# ──────────────────────────────────────────────────────────────
genai.configure(api_key=os.getenv("GEMINI_APIKEY"))
gemini = genai.GenerativeModel(
"gemini-1.5-pro-latest",
generation_config=dict(temperature=0.7, top_p=0.9, response_mime_type="text/plain"),
)
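# The GEMINI_APIKEY environment variable (e.g. a Spaces secret) must be set before
# the app starts; a missing or invalid key typically only surfaces later, when
# generate_content is called further down.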
# ──────────────────────────────────────────────────────────────
# Streamlit layout
# ──────────────────────────────────────────────────────────────
st.set_page_config(page_title="BizIntel AI Ultra", layout="wide")
st.title("📊 BizIntel AI Ultra – Advanced Analytics + Gemini 1.5 Pro")
# ──────────────────────────────────────────────────────────────
# 1) Data source selection
# ──────────────────────────────────────────────────────────────
choice = st.radio("Select data source", ["Upload CSV / Excel", "Connect to SQL Database"])
csv_path: str | None = None
if choice.startswith("Upload"):
    up = st.file_uploader("CSV or Excel (≤ 500 MB)", type=["csv", "xlsx", "xls"])
    if up:
        tmp = os.path.join(TMP, up.name)
        with open(tmp, "wb") as f:
            f.write(up.read())
        if up.name.lower().endswith(".csv"):
            csv_path = tmp
        else:
            try:
                pd.read_excel(tmp).to_csv(tmp + ".csv", index=False)
                csv_path = tmp + ".csv"
            except Exception as e:
                st.error(f"Excel parse failed: {e}")
else:
    eng = st.selectbox("DB engine", SUPPORTED_ENGINES, key="db_eng")
    conn = st.text_input("SQLAlchemy connection string")
    if conn:
        try:
            tbl = st.selectbox("Table", list_tables(conn))
            if st.button("Fetch table"):
                csv_path = fetch_data_from_db(conn, tbl)
                st.success(f"Fetched **{tbl}**")
        except Exception as e:
            st.error(f"DB error: {e}")
if not csv_path:
    st.stop()
with open(csv_path, "rb") as f:
    st.download_button("⬇️ Download working CSV", f, file_name=os.path.basename(csv_path))
# ──────────────────────────────────────────────────────────────
# 2) Column pickers
# ──────────────────────────────────────────────────────────────
df_head = pd.read_csv(csv_path, nrows=5)
st.dataframe(df_head)
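# NOTE: column dtypes are inferred from this 5-row preview only; a numeric column
# whose first rows contain stray text will be typed as object and won't appear in
# the metric picker below.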
date_col = st.selectbox("Date/time column", df_head.columns)
numeric_df = df_head.select_dtypes("number")
metric_col = st.selectbox(
"Numeric metric column",
[c for c in numeric_df.columns if c != date_col] or numeric_df.columns
)
if metric_col is None:
    st.warning("Need at least one numeric column.")
    st.stop()
# ──────────────────────────────────────────────────────────────
# 3) Quick data summary & trend chart
# ──────────────────────────────────────────────────────────────
summary_md = parse_csv_tool(csv_path)
trend_res = plot_metric_tool(csv_path, date_col, metric_col)
if isinstance(trend_res, tuple):
    trend_fig, _ = trend_res
elif isinstance(trend_res, go.Figure):
    trend_fig = trend_res
else:  # error message str
    st.warning(trend_res)
    trend_fig = None
if trend_fig is not None:
    st.subheader("📈 Trend")
    st.plotly_chart(trend_fig, use_container_width=True)
# ──────────────────────────────────────────────────────────────
# 4) Build clean series & ARIMA helpers
# ──────────────────────────────────────────────────────────────
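# build_series collapses duplicate timestamps by averaging, pins the series to the
# inferred frequency (falling back to daily), and interpolates the gaps so ARIMA
# gets a regular index. st.cache_data keeps the result across Streamlit reruns.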
@st.cache_data(show_spinner="Preparing series…")
def build_series(path, dcol, vcol):
    df = pd.read_csv(path, usecols=[dcol, vcol])
    df[dcol] = pd.to_datetime(df[dcol], errors="coerce")
    df[vcol] = pd.to_numeric(df[vcol], errors="coerce")
    df = df.dropna(subset=[dcol, vcol]).sort_values(dcol)
    if df.empty:
        raise ValueError("Not enough valid data.")
    s = df.set_index(dcol)[vcol].groupby(level=0).mean().sort_index()
    freq = pd.infer_freq(s.index) or "D"
    s = s.asfreq(freq).interpolate()
    return s, freq
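# The ARIMA order below is fixed at (1, 1, 1) as a deliberately simple default.
# A minimal order-selection sketch (hypothetical, not wired into the app) could
# instead grid-search small (p, d, q) combinations by AIC, e.g.:
#
#     best_order = min(
#         ((p, d, q) for p in range(3) for d in range(2) for q in range(3)),
#         key=lambda o: ARIMA(series, order=o).fit().aic,
#     )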
@st.cache_data(show_spinner="Fitting ARIMA…")
def fit_arima(series):
    warnings.simplefilter("ignore", ConvergenceWarning)
    return ARIMA(series, order=(1, 1, 1)).fit()
try:
    series, freq = build_series(csv_path, date_col, metric_col)
    horizon = 90 if freq == "D" else 3
    model_res = fit_arima(series)
    fc_obj = model_res.get_forecast(horizon)
    forecast = fc_obj.predicted_mean
    ci = fc_obj.conf_int()
except Exception as e:
    st.subheader(f"🔮 {metric_col} Forecast")
    st.warning(f"Forecast failed: {e}")
    forecast = ci = model_res = None
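# On failure everything is set to None, so the entire forecast / explainability
# block below is skipped rather than raising further errors.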
# ──────────────────────────────────────────────────────────────
# 5) Forecast plot & explainability
# ──────────────────────────────────────────────────────────────
if forecast is not None:
    fig = go.Figure()
    fig.add_scatter(x=series.index, y=series, mode="lines", name=metric_col)
    fig.add_scatter(x=forecast.index, y=forecast, mode="lines+markers", name="Forecast")
    fig.add_scatter(
        x=ci.index, y=ci.iloc[:, 1], mode="lines", line=dict(width=0), showlegend=False
    )
    fig.add_scatter(
        x=ci.index,
        y=ci.iloc[:, 0],
        mode="lines",
        line=dict(width=0),
        fill="tonexty",
        fillcolor="rgba(255,0,0,0.25)",
        showlegend=False,
    )
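    # The two zero-width traces above draw the confidence band: the upper bound is
    # plotted first, then the lower bound fills up to it via fill="tonexty".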
    fig.update_layout(
        title=f"{metric_col} Forecast ({horizon} steps)",
        xaxis_title=date_col,
        yaxis_title=metric_col,
        template="plotly_dark",
    )
    st.subheader(f"🔮 {metric_col} Forecast")
    st.plotly_chart(fig, use_container_width=True)
    # -- model summary -----------------------------------------------------
    st.subheader("📄 ARIMA Model Summary")
    st.code(model_res.summary().as_text())
    # -- coefficient interpretation ----------------------------------------
    ar, ma = model_res.arparams, model_res.maparams
    interp = []
    if ar.size:
        interp.append(
            f"• AR(1) = {ar[0]:.2f} → "
            f"{'strong' if abs(ar[0]) > 0.5 else 'moderate'} persistence."
        )
    if ma.size:
        interp.append(
            f"• MA(1) = {ma[0]:.2f} → "
            f"{'large' if abs(ma[0]) > 0.5 else 'modest'} shock adjustment."
        )
    st.subheader("🗒 Coefficient Interpretation")
    st.markdown("\n\n".join(interp) if interp else "N/A")
    # -- residual ACF ------------------------------------------------------
    st.subheader("🔍 Residual ACF")
    acf_png = os.path.join(TMP, "acf.png")
    acf_fig = plot_acf(model_res.resid.dropna(), lags=30, alpha=0.05)
    acf_fig.tight_layout()
    acf_fig.savefig(acf_png, dpi=120)
    plt.close(acf_fig)
    st.image(acf_png, use_container_width=True)
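    # Reading the ACF: bars that poke outside the shaded confidence band hint at
    # autocorrelation left in the residuals, i.e. structure the ARIMA(1,1,1) fit
    # has not captured.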
    # -- back-test ----------------------------------------------------------
    k = max(int(len(series) * 0.2), 10)
    train, test = series[:-k], series[-k:]
    bt_res = ARIMA(train, order=(1, 1, 1)).fit()
    bt_pred = bt_res.forecast(k)
    mape = (abs(bt_pred - test) / test).mean() * 100
    rmse = np.sqrt(((bt_pred - test) ** 2).mean())
    st.subheader("🧪 Back-test (last 20 %)")
    col1, col2 = st.columns(2)
    col1.metric("MAPE", f"{mape:.2f}%")
    col2.metric("RMSE", f"{rmse:,.0f}")
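    # Caveat: MAPE divides by the actual values, so it blows up (or becomes
    # undefined) when the test window contains values at or near zero.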
    # -- seasonal decomposition (optional) ----------------------------------
    with st.expander("Seasonal Decomposition"):
        try:
            period = {"D": 7, "H": 24, "M": 12}.get(freq)
            if period:
                dec = seasonal_decompose(series, period=period, model="additive")
                for comp in ["trend", "seasonal", "resid"]:
                    st.line_chart(getattr(dec, comp).dropna(), height=150)
            else:
                st.info("Frequency not suited for decomposition.")
        except Exception as e:
            st.info(f"Decomposition failed: {e}")
# ──────────────────────────────────────────────────────────────
# 6) Gemini strategy report
# ──────────────────────────────────────────────────────────────
prompt = (
"You are **BizIntel Strategist AI**.\n\n"
f"### Dataset Summary\n```\n{summary_md}\n```\n\n"
f"### {metric_col} Forecast\n```\n"
f"{forecast.to_string() if forecast is not None else 'N/A'}\n```"
"\nGenerate a Markdown report with:\n"
"• 5 insights\n• 3 actionable strategies\n• Risks / anomalies\n• Additional visuals."
)
with st.spinner("Gemini 1.5 Pro is thinking…"):
    md = gemini.generate_content(prompt).text
st.subheader("🚀 Strategy Recommendations (Gemini 1.5 Pro)")
st.markdown(md)
st.download_button("⬇️ Download Strategy (.md)", md, file_name="strategy.md")
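# Note: the Gemini call above is not cached, so every Streamlit rerun (any widget
# interaction below) triggers a fresh API request; wrapping it in st.cache_data
# keyed on the prompt would be one way to avoid that.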
# ──────────────────────────────────────────────────────────────
# 7) High-level dataset KPIs + optional EDA
# ──────────────────────────────────────────────────────────────
fulldf = pd.read_csv(csv_path, low_memory=False)
rows, cols = fulldf.shape
miss_pct = fulldf.isna().mean().mean() * 100
st.markdown("---")
st.subheader("📑 Dataset KPIs")
k1, k2, k3 = st.columns(3)
k1.metric("Rows", f"{rows:,}")
k2.metric("Columns", cols)
k3.metric("Missing %", f"{miss_pct:.1f}%")
with st.expander("Descriptive Statistics (numeric)"):
    st.dataframe(
        fulldf.describe().T.round(2).style.format(precision=2).background_gradient("Blues"),
        use_container_width=True,
    )
st.markdown("---")
st.subheader("🔍 Optional EDA Visuals")
if st.checkbox("Histogram"):
    col = st.selectbox("Variable", fulldf.select_dtypes("number").columns)
    hr = histogram_tool(csv_path, col)
    if isinstance(hr, tuple):
        st.plotly_chart(hr[0], use_container_width=True)
    else:
        st.warning(hr)
if st.checkbox("Scatter Matrix"):
    opts = fulldf.select_dtypes("number").columns.tolist()
    sel = st.multiselect("Columns", opts, default=opts[:3])
    if sel:
        sm = scatter_matrix_tool(csv_path, sel)
        if isinstance(sm, tuple):
            st.plotly_chart(sm[0], use_container_width=True)
        else:
            st.warning(sm)
if st.checkbox("Correlation Heat-map"):
    hm = corr_heatmap_tool(csv_path)
    if isinstance(hm, tuple):
        st.plotly_chart(hm[0], use_container_width=True)
    else:
        st.warning(hm)