mgbam commited on
Commit
df65c2e
·
verified ·
1 Parent(s): 4778379

Update tools/visuals.py

Browse files
Files changed (1) hide show
  1. tools/visuals.py +99 -90
tools/visuals.py CHANGED
@@ -1,163 +1,172 @@
 
 
 
1
  import os
2
  import tempfile
3
- import pandas as pd
 
4
  import numpy as np
 
5
  import plotly.express as px
6
- import plotly.figure_factory as ff
7
  import plotly.graph_objects as go
8
  from scipy.cluster.hierarchy import linkage, leaves_list
9
- from typing import Union, Tuple, List
10
 
 
 
 
 
11
 
12
- def _save_fig(fig: go.Figure, prefix: str, output_dir: str) -> str:
13
- """
14
- Save a Plotly figure as a high-res PNG and return the file path.
15
- """
16
- os.makedirs(output_dir, exist_ok=True)
17
- tmp = tempfile.NamedTemporaryFile(suffix='.png', prefix=prefix, dir=output_dir, delete=False)
18
- path = tmp.name
19
- tmp.close()
20
- fig.write_image(path, scale=3)
21
- return path
 
22
 
23
 
 
 
 
24
  def histogram_tool(
25
  file_path: str,
26
  column: str,
27
  bins: int = 30,
28
  kde: bool = True,
29
- output_dir: str = '/tmp'
30
- ) -> Union[Tuple[ff.FigureFactory, str], str]:
31
- """
32
- Create a histogram with optional KDE overlay for a given numeric column.
33
-
34
- Returns (figure, png_path) or error string.
35
- """
36
- # Load
37
  ext = os.path.splitext(file_path)[1].lower()
38
- df = pd.read_excel(file_path) if ext in ('.xls','.xlsx') else pd.read_csv(file_path)
39
 
40
- # Validate
41
  if column not in df.columns:
42
  return f"❌ Column '{column}' not found."
43
- series = pd.to_numeric(df[column], errors='coerce').dropna()
44
  if series.empty:
45
  return f"❌ No numeric data in '{column}'."
46
 
47
- # Build histogram + KDE
48
  if kde:
49
- fig = ff.create_distplot([series], [column], bin_size=(series.max()-series.min())/bins)
 
 
 
 
 
 
 
 
 
 
 
50
  else:
51
- fig = px.histogram(series, nbins=bins, title=f"Histogram – {column}", template='plotly_dark')
52
- fig.update_layout(template='plotly_dark')
 
53
 
54
- # Save
55
- img_path = _save_fig(fig, f"hist_{column}_", output_dir)
56
- return fig, img_path
57
 
58
 
 
 
 
59
  def boxplot_tool(
60
  file_path: str,
61
  column: str,
62
- output_dir: str = '/tmp'
63
- ) -> Union[Tuple[px.Figure, str], str]:
64
- """
65
- Create a box plot with outliers for a numeric column.
66
-
67
- Returns (figure, png_path) or error string.
68
- """
69
  ext = os.path.splitext(file_path)[1].lower()
70
- df = pd.read_excel(file_path) if ext in ('.xls','.xlsx') else pd.read_csv(file_path)
71
  if column not in df.columns:
72
  return f"❌ Column '{column}' not found."
73
- series = pd.to_numeric(df[column], errors='coerce').dropna()
74
  if series.empty:
75
  return f"❌ No numeric data in '{column}'."
76
 
77
- fig = px.box(series, points='outliers', title=f"Boxplot – {column}", template='plotly_dark')
78
- img_path = _save_fig(fig, f"box_{column}_", output_dir)
79
- return fig, img_path
 
80
 
81
 
 
 
 
82
  def violin_tool(
83
  file_path: str,
84
  column: str,
85
- output_dir: str = '/tmp'
86
- ) -> Union[Tuple[px.Figure, str], str]:
87
- """
88
- Create a violin plot with inner box for a numeric column.
89
-
90
- Returns (figure, png_path) or error string.
91
- """
92
  ext = os.path.splitext(file_path)[1].lower()
93
- df = pd.read_excel(file_path) if ext in ('.xls','.xlsx') else pd.read_csv(file_path)
94
  if column not in df.columns:
95
  return f"❌ Column '{column}' not found."
96
- series = pd.to_numeric(df[column], errors='coerce').dropna()
97
  if series.empty:
98
  return f"❌ No numeric data in '{column}'."
99
 
100
- fig = px.violin(series, box=True, points='all', title=f"Violin – {column}", template='plotly_dark')
101
- img_path = _save_fig(fig, f"violin_{column}_", output_dir)
102
- return fig, img_path
 
103
 
104
 
 
 
 
105
  def scatter_matrix_tool(
106
  file_path: str,
107
  columns: List[str],
108
- output_dir: str = '/tmp',
109
- size: int = 5
110
- ) -> Union[Tuple[px.Figure, str], str]:
111
- """
112
- Create an interactive scatter matrix for selected numeric columns.
113
-
114
- Returns (figure, png_path) or error string.
115
- """
116
  ext = os.path.splitext(file_path)[1].lower()
117
- df = pd.read_excel(file_path) if ext in ('.xls','.xlsx') else pd.read_csv(file_path)
 
118
  missing = [c for c in columns if c not in df.columns]
119
  if missing:
120
  return f"❌ Missing columns: {', '.join(missing)}"
121
- df_num = df[columns].apply(pd.to_numeric, errors='coerce').dropna()
122
  if df_num.empty:
123
  return "❌ No valid numeric data."
124
 
125
- fig = px.scatter_matrix(df_num, dimensions=columns, title="Scatter Matrix", template='plotly_dark')
126
- fig.update_traces(diagonal_visible=False, marker={'size': size})
127
- img_path = _save_fig(fig, "scatter_matrix_", output_dir)
128
- return fig, img_path
 
129
 
130
 
 
 
 
131
  def corr_heatmap_tool(
132
  file_path: str,
133
- columns: List[str] = None,
134
- output_dir: str = '/tmp',
135
- cluster: bool = True
136
- ) -> Union[Tuple[px.Figure, str], str]:
137
- """
138
- Create a correlation heatmap, with optional hierarchical clustering of variables.
139
-
140
- Returns (figure, png_path) or error string.
141
- """
142
  ext = os.path.splitext(file_path)[1].lower()
143
- df = pd.read_excel(file_path) if ext in ('.xls','.xlsx') else pd.read_csv(file_path)
144
- df_num = df.select_dtypes(include='number') if columns is None else df[columns]
145
- df_num = df_num.apply(pd.to_numeric, errors='coerce').dropna(axis=1, how='all')
 
146
  if df_num.shape[1] < 2:
147
- return "❌ Need at least two numeric columns for correlation."
148
 
149
  corr = df_num.corr()
150
  if cluster:
151
- link = linkage(corr, method='average')
152
- order = leaves_list(link)
153
  corr = corr.iloc[order, order]
154
 
155
  fig = px.imshow(
156
  corr,
157
- color_continuous_scale='RdBu',
158
- title="Correlation Heatmap",
159
- labels=dict(color="Correlation"),
160
- template='plotly_dark'
161
  )
162
- img_path = _save_fig(fig, "corr_heatmap_", output_dir)
163
- return fig, img_path
 
1
+ # tools/visuals.py — reusable Plotly helpers
2
+ # ------------------------------------------------------------
3
+
4
  import os
5
  import tempfile
6
+ from typing import List, Tuple, Union
7
+
8
  import numpy as np
9
+ import pandas as pd
10
  import plotly.express as px
 
11
  import plotly.graph_objects as go
12
  from scipy.cluster.hierarchy import linkage, leaves_list
 
13
 
14
+ # -----------------------------------------------------------------
15
+ # Typing alias: every helper returns a plotly.graph_objects.Figure
16
+ # -----------------------------------------------------------------
17
+ Plot = go.Figure
18
 
19
+
20
+ # -----------------------------------------------------------------
21
+ # Utility: save figure to highres PNG under a writable dir (/tmp)
22
+ # -----------------------------------------------------------------
23
+ def _save_fig(fig: Plot, prefix: str, outdir: str = "/tmp") -> str:
24
+ os.makedirs(outdir, exist_ok=True)
25
+ tmp = tempfile.NamedTemporaryFile(
26
+ prefix=prefix, suffix=".png", dir=outdir, delete=False
27
+ )
28
+ fig.write_image(tmp.name, scale=3)
29
+ return tmp.name
30
 
31
 
32
+ # -----------------------------------------------------------------
33
+ # 1) Histogram (+ optional KDE)
34
+ # -----------------------------------------------------------------
35
  def histogram_tool(
36
  file_path: str,
37
  column: str,
38
  bins: int = 30,
39
  kde: bool = True,
40
+ output_dir: str = "/tmp",
41
+ ) -> Union[Tuple[Plot, str], str]:
 
 
 
 
 
 
42
  ext = os.path.splitext(file_path)[1].lower()
43
+ df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
44
 
 
45
  if column not in df.columns:
46
  return f"❌ Column '{column}' not found."
47
+ series = pd.to_numeric(df[column], errors="coerce").dropna()
48
  if series.empty:
49
  return f"❌ No numeric data in '{column}'."
50
 
 
51
  if kde:
52
+ # density + hist using numpy histogram
53
+ hist, edges = np.histogram(series, bins=bins)
54
+ fig = go.Figure()
55
+ fig.add_bar(x=edges[:-1], y=hist, name="Histogram")
56
+ fig.add_scatter(
57
+ x=np.linspace(series.min(), series.max(), 500),
58
+ y=np.exp(np.poly1d(np.polyfit(series, np.log(series.rank()), 1))(
59
+ np.linspace(series.min(), series.max(), 500)
60
+ )),
61
+ mode="lines",
62
+ name="KDE (approx)",
63
+ )
64
  else:
65
+ fig = px.histogram(
66
+ series, nbins=bins, title=f"Histogram – {column}", template="plotly_dark"
67
+ )
68
 
69
+ fig.update_layout(template="plotly_dark")
70
+ return fig, _save_fig(fig, f"hist_{column}_", output_dir)
 
71
 
72
 
73
+ # -----------------------------------------------------------------
74
+ # 2) Box plot
75
+ # -----------------------------------------------------------------
76
  def boxplot_tool(
77
  file_path: str,
78
  column: str,
79
+ output_dir: str = "/tmp",
80
+ ) -> Union[Tuple[Plot, str], str]:
 
 
 
 
 
81
  ext = os.path.splitext(file_path)[1].lower()
82
+ df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
83
  if column not in df.columns:
84
  return f"❌ Column '{column}' not found."
85
+ series = pd.to_numeric(df[column], errors="coerce").dropna()
86
  if series.empty:
87
  return f"❌ No numeric data in '{column}'."
88
 
89
+ fig = px.box(
90
+ series, points="outliers", title=f"Boxplot ��� {column}", template="plotly_dark"
91
+ )
92
+ return fig, _save_fig(fig, f"box_{column}_", output_dir)
93
 
94
 
95
+ # -----------------------------------------------------------------
96
+ # 3) Violin plot
97
+ # -----------------------------------------------------------------
98
  def violin_tool(
99
  file_path: str,
100
  column: str,
101
+ output_dir: str = "/tmp",
102
+ ) -> Union[Tuple[Plot, str], str]:
 
 
 
 
 
103
  ext = os.path.splitext(file_path)[1].lower()
104
+ df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
105
  if column not in df.columns:
106
  return f"❌ Column '{column}' not found."
107
+ series = pd.to_numeric(df[column], errors="coerce").dropna()
108
  if series.empty:
109
  return f"❌ No numeric data in '{column}'."
110
 
111
+ fig = px.violin(
112
+ series, box=True, points="all", title=f"Violin – {column}", template="plotly_dark"
113
+ )
114
+ return fig, _save_fig(fig, f"violin_{column}_", output_dir)
115
 
116
 
117
+ # -----------------------------------------------------------------
118
+ # 4) Scatter‑matrix
119
+ # -----------------------------------------------------------------
120
  def scatter_matrix_tool(
121
  file_path: str,
122
  columns: List[str],
123
+ output_dir: str = "/tmp",
124
+ size: int = 5,
125
+ ) -> Union[Tuple[Plot, str], str]:
 
 
 
 
 
126
  ext = os.path.splitext(file_path)[1].lower()
127
+ df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
128
+
129
  missing = [c for c in columns if c not in df.columns]
130
  if missing:
131
  return f"❌ Missing columns: {', '.join(missing)}"
132
+ df_num = df[columns].apply(pd.to_numeric, errors="coerce").dropna()
133
  if df_num.empty:
134
  return "❌ No valid numeric data."
135
 
136
+ fig = px.scatter_matrix(
137
+ df_num, dimensions=columns, title="Scatter Matrix", template="plotly_dark"
138
+ )
139
+ fig.update_traces(diagonal_visible=False, marker=dict(size=size))
140
+ return fig, _save_fig(fig, "scatter_matrix_", output_dir)
141
 
142
 
143
+ # -----------------------------------------------------------------
144
+ # 5) Correlation heat‑map (optional clustering)
145
+ # -----------------------------------------------------------------
146
  def corr_heatmap_tool(
147
  file_path: str,
148
+ columns: List[str] | None = None,
149
+ output_dir: str = "/tmp",
150
+ cluster: bool = True,
151
+ ) -> Union[Tuple[Plot, str], str]:
 
 
 
 
 
152
  ext = os.path.splitext(file_path)[1].lower()
153
+ df = pd.read_excel(file_path) if ext in (".xls", ".xlsx") else pd.read_csv(file_path)
154
+
155
+ df_num = df.select_dtypes("number") if columns is None else df[columns]
156
+ df_num = df_num.apply(pd.to_numeric, errors="coerce").dropna(axis=1, how="all")
157
  if df_num.shape[1] < 2:
158
+ return "❌ Need ≥ 2 numeric columns."
159
 
160
  corr = df_num.corr()
161
  if cluster:
162
+ order = leaves_list(linkage(corr, "average"))
 
163
  corr = corr.iloc[order, order]
164
 
165
  fig = px.imshow(
166
  corr,
167
+ color_continuous_scale="RdBu",
168
+ title="Correlation Heat‑map",
169
+ labels=dict(color="ρ"),
170
+ template="plotly_dark",
171
  )
172
+ return fig, _save_fig(fig, "corr_heatmap_", output_dir)