# BizIntel_AI / tools / forecaster.py
# mgbam's picture
# Update tools/forecaster.py
# eec9db3 verified
# raw
# history blame
# 3.12 kB
import os
import tempfile
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
import plotly.graph_objects as go
def forecast_metric_tool(
    file_path: str,
    date_col: str,
    value_col: str,
    periods: int = 3,
    output_dir: str = "/tmp"
) -> "tuple[pd.DataFrame, str] | str":
    """
    Forecast a time-series metric from a CSV/Excel file and save a plot.

    The file is loaded, `date_col` is parsed as datetimes and `value_col` as
    numbers, rows with missing values are dropped, duplicate dates are
    averaged, and the series is re-indexed to its inferred frequency (daily
    when none can be inferred). An ARIMA(1,1,1) model is then fitted, the next
    `periods` steps are forecast, and a combined history+forecast figure is
    written as a PNG into `output_dir`.

    Args:
        file_path: Path to the input file. '.xls'/'.xlsx' are read with
            `pd.read_excel`; anything else is read with `pd.read_csv`.
        date_col: Name of the column holding the dates.
        value_col: Name of the column holding the metric to forecast.
        periods: Number of future steps to forecast; must be >= 1.
        output_dir: Directory for the PNG (created if missing).

    Returns:
        On success, a tuple `(forecast_df, plot_path)`:
            forecast_df (pd.DataFrame): predicted values in a single
                'Forecast' column, indexed by date.
            plot_path (str): full path to the saved PNG plot.
        On failure, a string starting with '❌' describing the problem.
    """
    # Validate the horizon before doing any I/O so a bad `periods` is reported
    # as a '❌' message instead of surfacing later from the model.
    try:
        steps = int(periods)
    except (TypeError, ValueError):
        steps = 0
    if steps < 1:
        return f"❌ 'periods' must be a positive integer, got {periods!r}."

    # 0) Load data (CSV or Excel)
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext in ('.xls', '.xlsx'):
            df = pd.read_excel(file_path)
        else:
            df = pd.read_csv(file_path)
    except Exception as e:
        return f"❌ Failed to load file: {e}"

    # 1) Validate columns
    for col in (date_col, value_col):
        if col not in df.columns:
            return f"❌ Column '{col}' not found."

    # 2) Parse dates and numeric
    try:
        df[date_col] = pd.to_datetime(df[date_col])
    except Exception:
        return f"❌ Could not parse '{date_col}' as dates."
    # Non-numeric values become NaN here and are removed by the dropna below.
    df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
    df = df.dropna(subset=[date_col, value_col])
    if df.empty:
        return f"❌ No valid rows after dropping NaNs in '{date_col}'/'{value_col}'."

    # 3) Aggregate duplicates & index (groupby leaves one row per date)
    df = (
        df[[date_col, value_col]]
        .groupby(date_col, as_index=True)
        .mean()
        .sort_index()
    )

    # pd.infer_freq raises ValueError on fewer than 3 dates; report that as a
    # normal '❌' error instead of letting the exception escape.
    n_dates = len(df)
    if n_dates < 3:
        return (
            f"❌ Need at least 3 distinct dates in '{date_col}' to forecast; "
            f"found {n_dates}."
        )

    # 4) Infer frequency
    try:
        freq = pd.infer_freq(df.index)
    except (TypeError, ValueError):
        freq = None
    if freq is None:
        freq = 'D'  # fallback
    # asfreq re-indexes onto a regular grid; dates absent from the data become
    # NaN rows. Duplicates were already collapsed in step 3, so any failure
    # here is reported rather than retried.
    try:
        df = df.asfreq(freq)
    except (TypeError, ValueError) as e:
        return f"❌ Could not align data to frequency '{freq}': {e}"

    # 5) Fit ARIMA
    try:
        model = ARIMA(df[value_col], order=(1, 1, 1))
        fit = model.fit()
    except Exception as e:
        return f"❌ ARIMA fitting failed: {e}"

    # 6) Forecast future
    try:
        fc_res = fit.get_forecast(steps=steps)
        forecast = fc_res.predicted_mean
    except Exception as e:
        return f"❌ Forecasting failed: {e}"
    forecast_df = forecast.to_frame(name='Forecast')

    # 7) Plot history + forecast
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=df.index, y=df[value_col],
            mode='lines+markers', name=value_col
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast.index, y=forecast,
            mode='lines+markers', name='Forecast'
        )
    )
    fig.update_layout(
        title=f"{value_col} Forecast",
        xaxis_title=date_col,
        yaxis_title=value_col,
        template='plotly_dark',
    )

    # 8) Save to temporary file
    plot_path = None
    try:
        os.makedirs(output_dir, exist_ok=True)
        # delete=False keeps the file after the handle closes; only the unique
        # name is needed here, the image itself is written by plotly below.
        with tempfile.NamedTemporaryFile(
            suffix='.png', prefix='forecast_', dir=output_dir, delete=False
        ) as tmp:
            plot_path = tmp.name
        fig.write_image(plot_path)
    except Exception as e:
        # Do not leave an empty placeholder PNG behind when the export fails.
        if plot_path is not None:
            try:
                os.remove(plot_path)
            except OSError:
                pass
        return f"❌ Failed to save plot: {e}"

    return forecast_df, plot_path