mgbam commited on
Commit
3651f7b
Β·
verified Β·
1 Parent(s): b04cfbb

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +12 -30
tools/forecaster.py CHANGED
@@ -1,50 +1,32 @@
1
  import pandas as pd
2
- import plotly.graph_objects as go
3
  from statsmodels.tsa.arima.model import ARIMA
 
4
 
5
- def forecast_tool(file_path: str, date_col: str) -> str:
6
  """
7
- Forecast next 3 periods of 'Sales'. Returns text summary and saves forecast_plot.png.
 
8
  """
9
  df = pd.read_csv(file_path)
10
 
11
  try:
12
  df[date_col] = pd.to_datetime(df[date_col])
13
  except Exception:
14
- return f"❌ Column '{date_col}' cannot be parsed as dates."
15
 
16
- if "Sales" not in df.columns:
17
- return "❌ CSV must contain a 'Sales' column."
18
 
19
  df.set_index(date_col, inplace=True)
20
- model = ARIMA(df["Sales"], order=(1, 1, 1))
21
  model_fit = model.fit()
22
  forecast = model_fit.forecast(steps=3)
23
 
24
- # Interactive Plotly forecast with confidence interval
25
- conf_int = model_fit.get_forecast(steps=3).conf_int()
26
- future_index = forecast.index
27
-
28
  fig = go.Figure()
29
- fig.add_scatter(x=df.index, y=df["Sales"], mode="lines", name="Sales")
30
- fig.add_scatter(x=future_index, y=forecast, mode="lines", name="Forecast")
31
- fig.add_scatter(
32
- x=future_index,
33
- y=conf_int.iloc[:, 0],
34
- mode="lines",
35
- fill=None,
36
- line=dict(width=0),
37
- showlegend=False,
38
- )
39
- fig.add_scatter(
40
- x=future_index,
41
- y=conf_int.iloc[:, 1],
42
- mode="lines",
43
- fill="tonexty",
44
- name="95% CI",
45
- line=dict(width=0),
46
- )
47
- fig.update_layout(title="Sales Forecast", template="plotly_dark")
48
  fig.write_image("forecast_plot.png")
49
 
50
  return forecast.to_frame(name="Forecast").to_string()
 
1
  import pandas as pd
 
2
  from statsmodels.tsa.arima.model import ARIMA
3
+ import plotly.graph_objects as go
4
 
5
+ def forecast_metric_tool(file_path: str, date_col: str, value_col: str):
6
  """
7
+ Forecast next 3 periods for any numeric metric.
8
+ Saves PNG and returns forecast DataFrame as text.
9
  """
10
  df = pd.read_csv(file_path)
11
 
12
  try:
13
  df[date_col] = pd.to_datetime(df[date_col])
14
  except Exception:
15
+ return f"❌ '{date_col}' not parseable as dates."
16
 
17
+ if value_col not in df.columns:
18
+ return f"❌ '{value_col}' column missing."
19
 
20
  df.set_index(date_col, inplace=True)
21
+ model = ARIMA(df[value_col], order=(1, 1, 1))
22
  model_fit = model.fit()
23
  forecast = model_fit.forecast(steps=3)
24
 
25
+ # Plot
 
 
 
26
  fig = go.Figure()
27
+ fig.add_scatter(x=df.index, y=df[value_col], mode="lines", name=value_col)
28
+ fig.add_scatter(x=forecast.index, y=forecast, mode="lines", name="Forecast")
29
+ fig.update_layout(title=f"{value_col} Forecast", template="plotly_dark")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  fig.write_image("forecast_plot.png")
31
 
32
  return forecast.to_frame(name="Forecast").to_string()