# BizIntel_AI / tools/forecaster.py
# Last change: "Update tools/forecaster.py" by mgbam (commit a9cd5f0, verified; 1.56 kB)
import pandas as pd
import plotly.graph_objects as go
from statsmodels.tsa.arima.model import ARIMA
def _extend_dates(index, steps):
    """Return `steps` dates after index[-1] spaced by the median gap, or None if no positive gap exists."""
    gaps = index.to_series().diff().dropna()
    if gaps.empty:
        return None
    step = gaps.median()
    if step <= pd.Timedelta(0):
        return None
    return pd.DatetimeIndex([index[-1] + step * k for k in range(1, steps + 1)])


def _save_forecast_plot(history, forecast, conf_int, value_col, plot_path):
    """Write a dark-themed Plotly chart of the history, the forecast and its interval band."""
    future_index = forecast.index
    fig = go.Figure()
    fig.add_scatter(x=history.index, y=history, mode="lines", name=value_col)
    fig.add_scatter(x=future_index, y=forecast, mode="lines", name="Forecast")
    # The lower bound is drawn first and invisibly so that the next trace's
    # fill="tonexty" shades the region between the lower and upper bounds.
    fig.add_scatter(
        x=future_index,
        y=conf_int.iloc[:, 0],
        mode="lines",
        fill=None,
        line=dict(width=0),
        showlegend=False,
    )
    fig.add_scatter(
        x=future_index,
        y=conf_int.iloc[:, 1],
        mode="lines",
        fill="tonexty",
        name="95% CI",
        line=dict(width=0),
    )
    fig.update_layout(title=f"{value_col} Forecast", template="plotly_dark")
    # Static export; Plotly needs an image-export engine (e.g. kaleido) installed.
    fig.write_image(plot_path)


def forecast_tool(
    file_path: str,
    date_col: str,
    *,
    value_col: str = "Sales",
    steps: int = 3,
    plot_path: str = "forecast_plot.png",
) -> str:
    """Forecast the next `steps` periods of a CSV column with ARIMA(1, 1, 1).

    Reads `file_path`, parses `date_col` as dates, orders the rows
    chronologically, fits ARIMA(1, 1, 1) to `value_col` and forecasts ahead.
    A chart of the history, forecast and 95% confidence band is saved to
    `plot_path`.

    Args:
        file_path: Path of the CSV file to read.
        date_col: Name of the column holding the observation dates.
        value_col: Name of the numeric column to forecast (default "Sales").
        steps: Number of future periods to forecast (default 3).
        plot_path: Where the chart image is written (default "forecast_plot.png").

    Returns:
        A text table of forecast values, or a message starting with "❌"
        describing why the input could not be used.

    Raises:
        Errors from reading the CSV (e.g. FileNotFoundError), from model
        fitting, and from the image export are propagated unchanged.
    """
    df = pd.read_csv(file_path)

    # Check presence separately so a missing column is not reported as a parse failure.
    if date_col not in df.columns:
        return f"❌ CSV must contain a '{date_col}' column."
    try:
        df[date_col] = pd.to_datetime(df[date_col])
    except (ValueError, TypeError, OverflowError):
        return f"❌ Column '{date_col}' cannot be parsed as dates."

    if value_col not in df.columns:
        return f"❌ CSV must contain a '{value_col}' column."
    if not pd.api.types.is_numeric_dtype(df[value_col]):
        return f"❌ Column '{value_col}' must be numeric."

    # ARIMA treats observations as a time-ordered sequence, but CSV rows may
    # be unordered. Rows whose date is missing (NaT) cannot be placed in time.
    series = (
        df.dropna(subset=[date_col])
        .set_index(date_col)[value_col]
        .sort_index()
    )
    if series.dropna().empty:
        return f"❌ Column '{value_col}' has no usable values."

    model_fit = ARIMA(series, order=(1, 1, 1)).fit()
    # One prediction object yields both the point forecast and its interval
    # (conf_int's default alpha of 0.05 gives the 95% band labelled in the plot).
    prediction = model_fit.get_forecast(steps=steps)
    forecast = prediction.predicted_mean
    conf_int = prediction.conf_int()

    # statsmodels falls back to a non-date forecast index when it cannot infer a
    # regular date frequency. Re-label the forecast with extrapolated dates so
    # the table and the chart line up with the history.
    if not isinstance(forecast.index, pd.DatetimeIndex):
        future = _extend_dates(series.index, steps)
        if future is not None:
            forecast = forecast.set_axis(future)
            conf_int = conf_int.set_axis(future)

    _save_forecast_plot(series, forecast, conf_int, value_col, plot_path)
    return forecast.to_frame(name="Forecast").to_string()