mgbam committed on
Commit
9538f35
·
verified ·
1 Parent(s): dc5ae18

Update tools/visuals.py

Browse files
Files changed (1) hide show
  1. tools/visuals.py +102 -71
tools/visuals.py CHANGED
@@ -1,132 +1,163 @@
1
import os
import tempfile
from typing import Union, Tuple

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
 
 
 
6
 
7
def _save_fig(fig, prefix: str, output_dir: str) -> str:
    """
    Save a Plotly figure as a PNG inside ``output_dir`` and return its path.

    Args:
        fig: Object exposing ``write_image(path, scale=...)`` (a Plotly figure).
        prefix: File-name prefix. Characters that are not safe in a file name
            (for example a ``/`` coming from a column name) are replaced by ``_``.
        output_dir: Target directory; created when it does not exist.

    Returns:
        Path of the PNG that was written.

    Raises:
        Exception: Whatever ``fig.write_image`` raises. The empty placeholder
            file is removed before the error propagates.
    """
    import re

    os.makedirs(output_dir, exist_ok=True)
    # A path separator in the prefix would make tempfile target a sub-directory
    # that does not exist, so neutralise anything that is not filename-safe.
    safe_prefix = re.sub(r'[^\w.-]+', '_', prefix)
    with tempfile.NamedTemporaryFile(
        suffix='.png', prefix=safe_prefix, dir=output_dir, delete=False
    ) as tmp:
        path = tmp.name
    try:
        fig.write_image(path, scale=2)
    except Exception:
        # delete=False leaves a zero-byte file behind; do not let failed
        # exports accumulate in output_dir.
        try:
            os.remove(path)
        except OSError:
            pass  # best effort: the original export error is what matters
        raise
    return path
20
 
21
 
def histogram_tool(
    file_path: str,
    column: str,
    output_dir: str = '/tmp',
    bins: int = 30
) -> Union[Tuple["go.Figure", str], str]:
    """
    Build a histogram for one numeric column of a CSV/Excel file.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        column: Name of the column to plot.
        output_dir: Directory that receives the PNG export.
        bins: Number of bins passed to Plotly as ``nbins``.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: the figure type is plotly.graph_objects.Figure; plotly.express has
    # no ``Figure`` attribute, so annotating with ``px.Figure`` fails at import.
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        return f"❌ Failed to load file: {exc}"

    if column not in df.columns:
        return f"❌ Column '{column}' not found."

    # Non-numeric cells become NaN so they drop out of the plot.
    df[column] = pd.to_numeric(df[column], errors='coerce')
    series = df[column].dropna()
    if series.empty:
        return f"❌ No valid numeric data in '{column}'."

    fig = px.histogram(
        df,
        x=column,
        nbins=bins,
        title=f"Histogram – {column}",
        template='plotly_dark'
    )
    img_path = _save_fig(fig, f"hist_{column}_", output_dir)
    return fig, img_path
60
 
61
 
62
def scatter_matrix_tool(
    file_path: str,
    cols: list[str],
    output_dir: str = '/tmp'
) -> Union[Tuple["go.Figure", str], str]:
    """
    Build a scatter matrix for the selected columns of a CSV/Excel file.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        cols: Column names used as the matrix dimensions.
        output_dir: Directory that receives the PNG export.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        return f"❌ Failed to load file: {exc}"

    missing = [c for c in cols if c not in df.columns]
    if missing:
        return f"❌ Missing columns: {', '.join(missing)}"

    # Keep only rows where every selected column parses as a number.
    df_num = df[cols].apply(pd.to_numeric, errors='coerce').dropna()
    if df_num.empty:
        return "❌ No valid numeric data in selected columns."

    fig = px.scatter_matrix(
        df_num,
        dimensions=cols,
        title="Scatter-Matrix",
        template='plotly_dark'
    )
    img_path = _save_fig(fig, "scatter_matrix_", output_dir)
    return fig, img_path
98
 
99
 
def corr_heatmap_tool(
    file_path: str,
    output_dir: str = '/tmp',
    color_continuous_scale: str = 'RdBu'
) -> Union[Tuple["go.Figure", str], str]:
    """
    Build a correlation heatmap over the numeric columns of a CSV/Excel file.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        output_dir: Directory that receives the PNG export.
        color_continuous_scale: Plotly colour scale name passed to ``px.imshow``.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        return f"❌ Failed to load file: {exc}"

    # select_dtypes already yields numeric columns, so no further coercion is needed.
    df_num = df.select_dtypes(include='number')
    if df_num.empty:
        return "❌ No numeric columns available for correlation."
    corr = df_num.corr()

    fig = px.imshow(
        corr,
        color_continuous_scale=color_continuous_scale,
        title="Correlation Heatmap",
        labels=dict(color="Correlation"),
        template='plotly_dark'
    )
    img_path = _save_fig(fig, "corr_heatmap_", output_dir)
    return fig, img_path
 
1
  import os
2
  import tempfile
3
  import pandas as pd
4
+ import numpy as np
5
  import plotly.express as px
6
+ import plotly.figure_factory as ff
7
+ import plotly.graph_objects as go
8
+ from scipy.cluster.hierarchy import linkage, leaves_list
9
+ from typing import Union, Tuple, List
10
 
11
+
12
def _save_fig(fig: "go.Figure", prefix: str, output_dir: str) -> str:
    """
    Save a Plotly figure as a high-res PNG and return the file path.

    Args:
        fig: Figure to export; only ``write_image(path, scale=...)`` is used.
        prefix: File-name prefix. Characters that are not safe in a file name
            (for example a ``/`` coming from a column name) are replaced by ``_``.
        output_dir: Target directory; created when it does not exist.

    Returns:
        Path of the PNG that was written.

    Raises:
        Exception: Whatever ``fig.write_image`` raises. The empty placeholder
            file is removed before the error propagates.
    """
    import re

    os.makedirs(output_dir, exist_ok=True)
    # A path separator in the prefix would make tempfile target a sub-directory
    # that does not exist, so neutralise anything that is not filename-safe.
    safe_prefix = re.sub(r'[^\w.-]+', '_', prefix)
    with tempfile.NamedTemporaryFile(
        suffix='.png', prefix=safe_prefix, dir=output_dir, delete=False
    ) as tmp:
        path = tmp.name
    try:
        fig.write_image(path, scale=3)
    except Exception:
        # delete=False leaves a zero-byte file behind; do not let failed
        # exports accumulate in output_dir.
        try:
            os.remove(path)
        except OSError:
            pass  # best effort: the original export error is what matters
        raise
    return path
22
 
23
 
def histogram_tool(
    file_path: str,
    column: str,
    bins: int = 30,
    kde: bool = True,
    output_dir: str = '/tmp'
) -> Union[Tuple["go.Figure", str], str]:
    """
    Create a histogram with optional KDE overlay for a given numeric column.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        column: Name of the column to plot.
        bins: Desired number of bins.
        kde: Overlay a kernel density estimate. Ignored (a plain histogram is
            drawn) when the data has no spread or ``bins`` is not positive,
            because no bin width or density can be derived then.
        output_dir: Directory that receives the PNG export.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure; neither ``ff.FigureFactory`` nor
    # ``px.Figure`` exists, so using them here breaks the module at import.
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        # Honour the "error string" contract instead of raising on a bad file.
        return f"❌ Failed to load file: {exc}"

    if column not in df.columns:
        return f"❌ Column '{column}' not found."
    series = pd.to_numeric(df[column], errors='coerce').dropna()
    if series.empty:
        return f"❌ No numeric data in '{column}'."

    title = f"Histogram – {column}"
    span = series.max() - series.min()
    # A zero/non-finite span or non-positive bin count gives a useless bin
    # width (0, inf or ZeroDivisionError) and a degenerate KDE.
    can_kde = kde and bins > 0 and bool(np.isfinite(span)) and span > 0
    if can_kde:
        fig = ff.create_distplot([series], [column], bin_size=span / bins)
    else:
        fig = px.histogram(series, nbins=bins, title=title, template='plotly_dark')
    # create_distplot sets neither theme nor title; applying both here keeps
    # the two branches visually consistent.
    fig.update_layout(template='plotly_dark', title=title)

    img_path = _save_fig(fig, f"hist_{column}_", output_dir)
    return fig, img_path
57
 
58
 
59
def boxplot_tool(
    file_path: str,
    column: str,
    output_dir: str = '/tmp'
) -> Union[Tuple["go.Figure", str], str]:
    """
    Create a box plot with outliers for a numeric column.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        column: Name of the column to plot.
        output_dir: Directory that receives the PNG export.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        # Honour the "error string" contract instead of raising on a bad file.
        return f"❌ Failed to load file: {exc}"

    if column not in df.columns:
        return f"❌ Column '{column}' not found."
    # Non-numeric cells are coerced to NaN and dropped.
    series = pd.to_numeric(df[column], errors='coerce').dropna()
    if series.empty:
        return f"❌ No numeric data in '{column}'."

    fig = px.box(series, points='outliers', title=f"Boxplot – {column}", template='plotly_dark')
    img_path = _save_fig(fig, f"box_{column}_", output_dir)
    return fig, img_path
80
 
81
+
82
def violin_tool(
    file_path: str,
    column: str,
    output_dir: str = '/tmp'
) -> Union[Tuple["go.Figure", str], str]:
    """
    Create a violin plot with inner box for a numeric column.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        column: Name of the column to plot.
        output_dir: Directory that receives the PNG export.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        # Honour the "error string" contract instead of raising on a bad file.
        return f"❌ Failed to load file: {exc}"

    if column not in df.columns:
        return f"❌ Column '{column}' not found."
    # Non-numeric cells are coerced to NaN and dropped.
    series = pd.to_numeric(df[column], errors='coerce').dropna()
    if series.empty:
        return f"❌ No numeric data in '{column}'."

    fig = px.violin(series, box=True, points='all', title=f"Violin – {column}", template='plotly_dark')
    img_path = _save_fig(fig, f"violin_{column}_", output_dir)
    return fig, img_path
103
+
104
+
105
def scatter_matrix_tool(
    file_path: str,
    columns: List[str],
    output_dir: str = '/tmp',
    size: int = 5
) -> Union[Tuple["go.Figure", str], str]:
    """
    Create an interactive scatter matrix for selected numeric columns.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        columns: Column names used as the matrix dimensions.
        output_dir: Directory that receives the PNG export.
        size: Marker size applied to every trace.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        # Honour the "error string" contract instead of raising on a bad file.
        return f"❌ Failed to load file: {exc}"

    missing = [c for c in columns if c not in df.columns]
    if missing:
        return f"❌ Missing columns: {', '.join(missing)}"
    # Keep only rows where every selected column parses as a number.
    df_num = df[columns].apply(pd.to_numeric, errors='coerce').dropna()
    if df_num.empty:
        return "❌ No valid numeric data."

    fig = px.scatter_matrix(df_num, dimensions=columns, title="Scatter Matrix", template='plotly_dark')
    # The diagonal (a variable against itself) carries no information, so hide it.
    fig.update_traces(diagonal_visible=False, marker={'size': size})
    img_path = _save_fig(fig, "scatter_matrix_", output_dir)
    return fig, img_path
129
 
130
 
def corr_heatmap_tool(
    file_path: str,
    columns: Union[List[str], None] = None,
    output_dir: str = '/tmp',
    cluster: bool = True
) -> Union[Tuple["go.Figure", str], str]:
    """
    Create a correlation heatmap, with optional hierarchical clustering of variables.

    Args:
        file_path: Path to the data file; ``.xls``/``.xlsx`` are read with
            ``pd.read_excel``, everything else with ``pd.read_csv``.
        columns: Columns to correlate; ``None`` selects every numeric column.
        output_dir: Directory that receives the PNG export.
        cluster: Reorder rows/columns by average-linkage clustering of the
            correlation matrix so related variables sit next to each other.

    Returns:
        ``(figure, png_path)`` on success, otherwise an error string starting
        with '❌'.
    """
    # NOTE: annotated with go.Figure because plotly.express exposes no
    # ``Figure`` attribute (``px.Figure`` raises AttributeError at import).
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
    except Exception as exc:
        # Honour the "error string" contract instead of raising on a bad file.
        return f"❌ Failed to load file: {exc}"

    if columns is None:
        df_num = df.select_dtypes(include='number')
    else:
        # Report unknown columns as an error string rather than a KeyError.
        missing = [c for c in columns if c not in df.columns]
        if missing:
            return f"❌ Missing columns: {', '.join(missing)}"
        df_num = df[columns]

    # Columns that contain no parseable number at all are dropped.
    df_num = df_num.apply(pd.to_numeric, errors='coerce').dropna(axis=1, how='all')
    if df_num.shape[1] < 2:
        return "❌ Need at least two numeric columns for correlation."

    corr = df_num.corr()
    if cluster:
        # Constant columns yield NaN correlations and linkage() rejects
        # non-finite input; treat them as uncorrelated for ordering only.
        # The displayed matrix keeps its NaNs.
        link = linkage(corr.fillna(0.0).to_numpy(), method='average')
        order = leaves_list(link)
        corr = corr.iloc[order, order]

    fig = px.imshow(
        corr,
        color_continuous_scale='RdBu',
        title="Correlation Heatmap",
        labels=dict(color="Correlation"),
        template='plotly_dark'
    )
    img_path = _save_fig(fig, "corr_heatmap_", output_dir)
    return fig, img_path