mgbam commited on
Commit
a9cd5f0
·
verified ·
1 Parent(s): bf400de

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +30 -25
tools/forecaster.py CHANGED
@@ -1,30 +1,13 @@
1
  import pandas as pd
2
- import matplotlib.pyplot as plt
3
  from statsmodels.tsa.arima.model import ARIMA
4
 
5
- def forecast_tool(file_path: str, date_col: str | None = None) -> str:
6
  """
7
- Forecast the next 3 periods of the 'Sales' column.
8
- • If date_col is provided, use it.
9
- • Otherwise auto‑detect the first column that can be parsed as dates.
10
-
11
- Returns human‑readable summary and saves 'forecast_plot.png'.
12
  """
13
  df = pd.read_csv(file_path)
14
 
15
- # Auto‑detect date column if not specified
16
- if date_col is None:
17
- for col in df.columns:
18
- try:
19
- pd.to_datetime(df[col])
20
- date_col = col
21
- break
22
- except Exception:
23
- continue
24
- if date_col is None:
25
- return "❌ No parseable date column found."
26
-
27
- # Parse the date column
28
  try:
29
  df[date_col] = pd.to_datetime(df[date_col])
30
  except Exception:
@@ -38,8 +21,30 @@ def forecast_tool(file_path: str, date_col: str | None = None) -> str:
38
  model_fit = model.fit()
39
  forecast = model_fit.forecast(steps=3)
40
 
41
- forecast_df = pd.DataFrame(forecast, columns=["Forecast"])
42
- forecast_df.plot(title="Sales Forecast", figsize=(10, 6))
43
- plt.savefig("forecast_plot.png")
44
-
45
- return forecast_df.to_string()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
+ import plotly.graph_objects as go
3
  from statsmodels.tsa.arima.model import ARIMA
4
 
5
+ def forecast_tool(file_path: str, date_col: str) -> str:
6
  """
7
+ Forecast next 3 periods of 'Sales'. Returns text summary and saves forecast_plot.png.
 
 
 
 
8
  """
9
  df = pd.read_csv(file_path)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
  df[date_col] = pd.to_datetime(df[date_col])
13
  except Exception:
 
21
  model_fit = model.fit()
22
  forecast = model_fit.forecast(steps=3)
23
 
24
+ # Interactive Plotly forecast with confidence interval
25
+ conf_int = model_fit.get_forecast(steps=3).conf_int()
26
+ future_index = forecast.index
27
+
28
+ fig = go.Figure()
29
+ fig.add_scatter(x=df.index, y=df["Sales"], mode="lines", name="Sales")
30
+ fig.add_scatter(x=future_index, y=forecast, mode="lines", name="Forecast")
31
+ fig.add_scatter(
32
+ x=future_index,
33
+ y=conf_int.iloc[:, 0],
34
+ mode="lines",
35
+ fill=None,
36
+ line=dict(width=0),
37
+ showlegend=False,
38
+ )
39
+ fig.add_scatter(
40
+ x=future_index,
41
+ y=conf_int.iloc[:, 1],
42
+ mode="lines",
43
+ fill="tonexty",
44
+ name="95% CI",
45
+ line=dict(width=0),
46
+ )
47
+ fig.update_layout(title="Sales Forecast", template="plotly_dark")
48
+ fig.write_image("forecast_plot.png")
49
+
50
+ return forecast.to_frame(name="Forecast").to_string()