mgbam commited on
Commit
92cca14
Β·
verified Β·
1 Parent(s): de6f1e8

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +44 -54
tools/forecaster.py CHANGED
@@ -1,8 +1,10 @@
 
1
  import os
2
  import tempfile
3
  import pandas as pd
4
  from statsmodels.tsa.arima.model import ARIMA
5
  import plotly.graph_objects as go
 
6
 
7
 
8
  def forecast_metric_tool(
@@ -11,44 +13,38 @@ def forecast_metric_tool(
11
  value_col: str,
12
  periods: int = 3,
13
  output_dir: str = "/tmp"
14
- ):
15
  """
16
- Load a CSV or Excel file, parse a time series metric, fit an ARIMA(1,1,1) model,
17
- forecast the next `periods` steps, and save a combined history+forecast plot.
18
 
19
  Returns:
20
- forecast_df (pd.DataFrame): next-period predicted values, indexed by date.
21
- plot_path (str): full path to the saved PNG plot.
22
-
23
- Errors return a string starting with '❌' describing the problem.
24
  """
25
- # 0) Load data (CSV or Excel)
26
  ext = os.path.splitext(file_path)[1].lower()
27
  try:
28
- if ext in ('.xls', '.xlsx'):
29
- df = pd.read_excel(file_path)
30
- else:
31
- df = pd.read_csv(file_path)
32
- except Exception as e:
33
- return f"❌ Failed to load file: {e}"
34
 
35
- # 1) Validate columns
36
- for col in (date_col, value_col):
37
- if col not in df.columns:
38
- return f"❌ Column '{col}' not found."
39
 
40
- # 2) Parse dates and numeric
41
  try:
42
- df[date_col] = pd.to_datetime(df[date_col])
43
  except Exception:
44
  return f"❌ Could not parse '{date_col}' as dates."
45
-
46
  df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
47
  df = df.dropna(subset=[date_col, value_col])
48
  if df.empty:
49
- return f"❌ No valid rows after dropping NaNs in '{date_col}'/'{value_col}'."
50
 
51
- # 3) Aggregate duplicates & index
52
  df = (
53
  df[[date_col, value_col]]
54
  .groupby(date_col, as_index=True)
@@ -56,56 +52,50 @@ def forecast_metric_tool(
56
  .sort_index()
57
  )
58
 
59
- # 4) Infer frequency
60
- freq = pd.infer_freq(df.index)
61
- if freq is None:
62
- freq = 'D' # fallback
63
  try:
64
  df = df.asfreq(freq)
65
- except ValueError as e:
66
- # if duplicates remain
67
  df = df[~df.index.duplicated(keep='first')].asfreq(freq)
68
 
69
- # 5) Fit ARIMA
70
  try:
71
  model = ARIMA(df[value_col], order=(1, 1, 1))
72
  fit = model.fit()
73
- except Exception as e:
74
- return f"❌ ARIMA fitting failed: {e}"
75
 
76
- # 6) Forecast future
77
- fc_res = fit.get_forecast(steps=periods)
78
- forecast = fc_res.predicted_mean
 
 
 
79
  forecast_df = forecast.to_frame(name='Forecast')
80
 
81
- # 7) Plot history + forecast
82
- fig = go.Figure()
83
- fig.add_trace(
84
- go.Scatter(
85
- x=df.index, y=df[value_col],
86
- mode='lines+markers', name=value_col
87
- )
88
- )
89
- fig.add_trace(
90
- go.Scatter(
91
- x=forecast.index, y=forecast,
92
- mode='lines+markers', name='Forecast'
93
- )
94
  )
95
  fig.update_layout(
96
  title=f"{value_col} Forecast",
97
  xaxis_title=date_col,
98
  yaxis_title=value_col,
99
- template='plotly_dark',
100
  )
101
 
102
- # 8) Save to temporary file
103
  os.makedirs(output_dir, exist_ok=True)
104
- tmp = tempfile.NamedTemporaryFile(
105
- suffix='.png', prefix='forecast_', dir=output_dir, delete=False
106
- )
107
  plot_path = tmp.name
108
  tmp.close()
109
- fig.write_image(plot_path)
 
 
 
110
 
111
  return forecast_df, plot_path
 
1
+ # tools/forecaster.py
2
  import os
3
  import tempfile
4
  import pandas as pd
5
  from statsmodels.tsa.arima.model import ARIMA
6
  import plotly.graph_objects as go
7
+ from typing import Tuple, Union
8
 
9
 
10
  def forecast_metric_tool(
 
13
  value_col: str,
14
  periods: int = 3,
15
  output_dir: str = "/tmp"
16
+ ) -> Union[Tuple[pd.DataFrame, str], str]:
17
  """
18
+ Load CSV or Excel, parse a time series metric, fit ARIMA(1,1,1),
19
+ forecast next `periods` steps, return DataFrame and PNG path.
20
 
21
  Returns:
22
+ - (forecast_df, plot_path) on success
23
+ - error string starting with '❌' on failure
 
 
24
  """
25
+ # Load data
26
  ext = os.path.splitext(file_path)[1].lower()
27
  try:
28
+ df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
29
+ except Exception as exc:
30
+ return f"❌ Failed to load file: {exc}"
 
 
 
31
 
32
+ # Validate columns
33
+ missing = [c for c in (date_col, value_col) if c not in df.columns]
34
+ if missing:
35
+ return f"❌ Missing column(s): {', '.join(missing)}"
36
 
37
+ # Parse and clean
38
  try:
39
+ df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
40
  except Exception:
41
  return f"❌ Could not parse '{date_col}' as dates."
 
42
  df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
43
  df = df.dropna(subset=[date_col, value_col])
44
  if df.empty:
45
+ return f"❌ No valid data after cleaning '{date_col}'/'{value_col}'"
46
 
47
+ # Aggregate duplicates and sort
48
  df = (
49
  df[[date_col, value_col]]
50
  .groupby(date_col, as_index=True)
 
52
  .sort_index()
53
  )
54
 
55
+ # Infer frequency
56
+ freq = pd.infer_freq(df.index) or 'D'
 
 
57
  try:
58
  df = df.asfreq(freq)
59
+ except Exception:
 
60
  df = df[~df.index.duplicated(keep='first')].asfreq(freq)
61
 
62
+ # Fit ARIMA
63
  try:
64
  model = ARIMA(df[value_col], order=(1, 1, 1))
65
  fit = model.fit()
66
+ except Exception as exc:
67
+ return f"❌ ARIMA fitting failed: {exc}"
68
 
69
+ # Forecast
70
+ try:
71
+ pred = fit.get_forecast(steps=periods)
72
+ forecast = pred.predicted_mean
73
+ except Exception as exc:
74
+ return f"❌ Forecast generation failed: {exc}"
75
  forecast_df = forecast.to_frame(name='Forecast')
76
 
77
+ # Plot history + forecast
78
+ fig = go.Figure(
79
+ data=[
80
+ go.Scatter(x=df.index, y=df[value_col], mode='lines', name='History'),
81
+ go.Scatter(x=forecast.index, y=forecast, mode='lines+markers', name='Forecast')
82
+ ]
 
 
 
 
 
 
 
83
  )
84
  fig.update_layout(
85
  title=f"{value_col} Forecast",
86
  xaxis_title=date_col,
87
  yaxis_title=value_col,
88
+ template='plotly_dark'
89
  )
90
 
91
+ # Save PNG
92
  os.makedirs(output_dir, exist_ok=True)
93
+ tmp = tempfile.NamedTemporaryFile(suffix='.png', prefix='forecast_', dir=output_dir, delete=False)
 
 
94
  plot_path = tmp.name
95
  tmp.close()
96
+ try:
97
+ fig.write_image(plot_path, scale=2)
98
+ except Exception as exc:
99
+ return f"❌ Plot saving failed: {exc}"
100
 
101
  return forecast_df, plot_path