# BizIntel_AI / tools/visuals.py
# mgbam's picture
# Update tools/visuals.py
# df65c2e verified
# tools/visuals.py — reusable Plotly helpers
# ------------------------------------------------------------
import os
import tempfile
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.cluster.hierarchy import linkage, leaves_list
# -----------------------------------------------------------------
# Typing alias for plotly.graph_objects.Figure, used in the helpers'
# annotations. The public helpers below return (Plot, png_path) on
# success and an error-message string when input validation fails.
# -----------------------------------------------------------------
Plot = go.Figure
# -----------------------------------------------------------------
# Utility: save figure to high‑res PNG under a writable dir (/tmp)
# -----------------------------------------------------------------
def _save_fig(fig: Plot, prefix: str, outdir: str = "/tmp") -> str:
    """Write *fig* to a uniquely named PNG inside *outdir* and return its path.

    Args:
        fig: Figure to export; its ``write_image`` method is called with
            ``scale=3``.
        prefix: Leading part of the generated file name. Path separators in
            it are replaced with ``_`` so a value such as ``"hist_a/b_"``
            cannot point into a non-existent sub-directory.
        outdir: Target directory; created if it does not exist.

    Returns:
        Path of the written PNG file.

    Raises:
        Exception: Whatever ``fig.write_image`` raises is propagated after the
            empty placeholder file has been removed.
    """
    os.makedirs(outdir, exist_ok=True)

    # Callers build the prefix from user-supplied column names, which may
    # contain "/" or "\\"; tempfile would treat those as directory parts.
    safe_prefix = prefix
    for sep in (os.sep, os.altsep):
        if sep:
            safe_prefix = safe_prefix.replace(sep, "_")

    # Only the unique name is needed. Closing the handle right away avoids
    # leaking a descriptor and lets the exporter reopen the file on platforms
    # that forbid a second open of a file that is still held open.
    with tempfile.NamedTemporaryFile(
        prefix=safe_prefix, suffix=".png", dir=outdir, delete=False
    ) as tmp:
        path = tmp.name

    try:
        fig.write_image(path, scale=3)
    except Exception:
        # Do not leave a zero-byte PNG behind when the export fails.
        if os.path.exists(path):
            os.remove(path)
        raise
    return path
# -----------------------------------------------------------------
# 1) Histogram (+ optional KDE)
# -----------------------------------------------------------------
def histogram_tool(
    file_path: str,
    column: str,
    bins: int = 30,
    kde: bool = True,
    output_dir: str = "/tmp",
) -> Union[Tuple[Plot, str], str]:
    """Plot a histogram of one numeric column, optionally with a KDE overlay.

    The file is read with ``pandas.read_excel`` when its extension is
    ``.xls``/``.xlsx`` and with ``pandas.read_csv`` otherwise. Values that
    cannot be parsed as numbers are dropped.

    Args:
        file_path: Path of the spreadsheet or CSV file to load.
        column: Name of the column to plot.
        bins: Number of histogram bins.
        kde: When true, draw count bars from ``numpy.histogram`` and overlay a
            Gaussian kernel density estimate scaled to the count axis. When
            false, delegate to ``plotly.express.histogram``.
        output_dir: Directory in which the PNG export is written.

    Returns:
        ``(figure, png_path)`` on success, or an error-message string when the
        column is missing or holds no numeric data.
    """
    ext = os.path.splitext(file_path)[1].lower()
    df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
    if column not in df.columns:
        return f"❌ Column '{column}' not found."
    series = pd.to_numeric(df[column], errors="coerce").dropna()
    if series.empty:
        return f"❌ No numeric data in '{column}'."

    title = f"Histogram – {column}"
    if kde:
        # Local import: the module header only imports scipy.cluster.
        from scipy.stats import gaussian_kde

        values = series.to_numpy(dtype=float)
        counts, edges = np.histogram(values, bins=bins)
        widths = np.diff(edges)
        # Centre each bar on its bin and give it the bin's width, so the bars
        # cover exactly the interval they count.
        centers = edges[:-1] + widths / 2

        fig = go.Figure()
        fig.add_bar(x=centers, y=counts, width=widths, name="Histogram")

        # A density estimate needs at least two distinct values; otherwise the
        # covariance is singular and gaussian_kde raises.
        if values.size > 1 and values.max() > values.min():
            grid = np.linspace(values.min(), values.max(), 500)
            density = gaussian_kde(values)(grid)
            # density integrates to 1; multiply by n * bin width to express
            # it in the same units (counts per bin) as the bars.
            fig.add_scatter(
                x=grid,
                y=density * values.size * widths.mean(),
                mode="lines",
                name="KDE (approx)",
            )
        fig.update_layout(title=title)
    else:
        fig = px.histogram(series, nbins=bins, title=title, template="plotly_dark")

    fig.update_layout(template="plotly_dark")
    return fig, _save_fig(fig, f"hist_{column}_", output_dir)
# -----------------------------------------------------------------
# 2) Box plot
# -----------------------------------------------------------------
def boxplot_tool(
    file_path: str,
    column: str,
    output_dir: str = "/tmp",
) -> Union[Tuple[Plot, str], str]:
    """Draw a box plot of one numeric column and export it as a PNG.

    Excel files (``.xls``/``.xlsx``) are read with ``pandas.read_excel``;
    anything else is read with ``pandas.read_csv``. Entries that cannot be
    converted to numbers are discarded before plotting.

    Args:
        file_path: Path of the data file.
        column: Column to visualise.
        output_dir: Directory that receives the PNG.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error-message string
        (missing column, or no numeric values in it).
    """
    _, suffix = os.path.splitext(file_path)
    if suffix.lower() in (".xls", ".xlsx"):
        table = pd.read_excel(file_path)
    else:
        table = pd.read_csv(file_path)

    if column not in table.columns:
        return f"❌ Column '{column}' not found."

    numeric = pd.to_numeric(table[column], errors="coerce").dropna()
    if numeric.empty:
        return f"❌ No numeric data in '{column}'."

    figure = px.box(
        numeric,
        points="outliers",
        title=f"Boxplot – {column}",
        template="plotly_dark",
    )
    png_path = _save_fig(figure, f"box_{column}_", output_dir)
    return figure, png_path
# -----------------------------------------------------------------
# 3) Violin plot
# -----------------------------------------------------------------
def violin_tool(
    file_path: str,
    column: str,
    output_dir: str = "/tmp",
) -> Union[Tuple[Plot, str], str]:
    """Render a violin plot (with inner box and all points) for one column.

    The input is loaded via ``pandas.read_excel`` for ``.xls``/``.xlsx``
    files and via ``pandas.read_csv`` for every other extension. Values that
    fail numeric conversion are dropped.

    Args:
        file_path: Location of the data file.
        column: Name of the column to plot.
        output_dir: Folder in which the PNG is saved.

    Returns:
        A ``(figure, png_path)`` tuple, or an error-message string if the
        column does not exist or contains no numeric data.
    """
    extension = os.path.splitext(file_path)[1].lower()
    is_excel = extension in (".xls", ".xlsx")
    frame = pd.read_excel(file_path) if is_excel else pd.read_csv(file_path)

    if column not in frame.columns:
        return f"❌ Column '{column}' not found."

    converted = pd.to_numeric(frame[column], errors="coerce")
    data = converted.dropna()
    if data.empty:
        return f"❌ No numeric data in '{column}'."

    plot_options = dict(
        box=True,
        points="all",
        title=f"Violin – {column}",
        template="plotly_dark",
    )
    figure = px.violin(data, **plot_options)
    return figure, _save_fig(figure, f"violin_{column}_", output_dir)
# -----------------------------------------------------------------
# 4) Scatter‑matrix
# -----------------------------------------------------------------
def scatter_matrix_tool(
    file_path: str,
    columns: List[str],
    output_dir: str = "/tmp",
    size: int = 5,
) -> Union[Tuple[Plot, str], str]:
    """Build a scatter-plot matrix over the given columns and save it as PNG.

    ``.xls``/``.xlsx`` inputs go through ``pandas.read_excel``; all other
    files go through ``pandas.read_csv``. Each requested column is coerced to
    numeric, and any row with a non-numeric or missing value in one of them
    is removed.

    Args:
        file_path: Path of the data file.
        columns: Columns to use as the matrix dimensions.
        output_dir: Directory for the PNG export.
        size: Marker size applied to every trace.

    Returns:
        ``(figure, png_path)`` on success; an error-message string when a
        requested column is absent or no complete numeric rows remain.
    """
    _, suffix = os.path.splitext(file_path)
    if suffix.lower() in (".xls", ".xlsx"):
        table = pd.read_excel(file_path)
    else:
        table = pd.read_csv(file_path)

    absent = []
    for name in columns:
        if name not in table.columns:
            absent.append(name)
    if absent:
        return f"❌ Missing columns: {', '.join(absent)}"

    coerced = table[columns].apply(pd.to_numeric, errors="coerce")
    complete_rows = coerced.dropna()
    if complete_rows.empty:
        return "❌ No valid numeric data."

    figure = px.scatter_matrix(
        complete_rows,
        dimensions=columns,
        title="Scatter Matrix",
        template="plotly_dark",
    )
    # Hide the diagonal panels and apply the requested marker size.
    figure.update_traces(diagonal_visible=False, marker={"size": size})
    png_path = _save_fig(figure, "scatter_matrix_", output_dir)
    return figure, png_path
# -----------------------------------------------------------------
# 5) Correlation heat‑map (optional clustering)
# -----------------------------------------------------------------
def corr_heatmap_tool(
    file_path: str,
    columns: List[str] | None = None,
    output_dir: str = "/tmp",
    cluster: bool = True,
) -> Union[Tuple[Plot, str], str]:
    """Plot a correlation heat-map, optionally reordered by clustering.

    The file is read with ``pandas.read_excel`` for ``.xls``/``.xlsx`` and
    with ``pandas.read_csv`` otherwise. Selected columns are coerced to
    numeric and columns with no numeric values at all are dropped.

    Args:
        file_path: Path of the data file.
        columns: Columns to correlate; ``None`` selects every numeric column.
        output_dir: Directory that receives the PNG export.
        cluster: When true, rows and columns are reordered by the leaf order
            of an average-linkage hierarchical clustering of the correlation
            matrix rows.

    Returns:
        ``(figure, png_path)`` on success, or an error-message string when a
        requested column is missing or fewer than two numeric columns remain.
    """
    ext = os.path.splitext(file_path)[1].lower()
    df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)

    if columns is None:
        df_num = df.select_dtypes("number")
    else:
        # Report unknown columns the same way scatter_matrix_tool does,
        # instead of letting the lookup raise KeyError.
        missing = [c for c in columns if c not in df.columns]
        if missing:
            return f"❌ Missing columns: {', '.join(missing)}"
        df_num = df[columns]

    df_num = df_num.apply(pd.to_numeric, errors="coerce").dropna(axis=1, how="all")
    if df_num.shape[1] < 2:
        return "❌ Need ≥ 2 numeric columns."

    corr = df_num.corr()
    if cluster:
        # Constant columns yield NaN correlations, which linkage rejects.
        # Treat them as "uncorrelated" for ordering purposes only; the
        # plotted matrix keeps its NaN cells.
        features = corr.fillna(0.0).to_numpy()
        order = leaves_list(linkage(features, "average"))
        corr = corr.iloc[order, order]

    fig = px.imshow(
        corr,
        # Pin the range to the full correlation interval so the diverging
        # scale is centred on zero regardless of the data.
        zmin=-1,
        zmax=1,
        color_continuous_scale="RdBu",
        title="Correlation Heat‑map",
        labels=dict(color="ρ"),
        template="plotly_dark",
    )
    return fig, _save_fig(fig, "corr_heatmap_", output_dir)