Spaces:
Sleeping
Sleeping
from typing import List | |
import numpy as np | |
import pandas as pd | |
import plotly.graph_objects as go | |
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset, drawn as a steelblue line.
        df2 (pd.DataFrame): Test dataset, drawn as a gold line.

    Returns:
        go.Figure: Figure containing one line trace per dataset.
    """
    # Create a Plotly figure
    fig = go.Figure()
    # Training data: first column vs. index, in steelblue
    fig.add_trace(
        go.Scatter(
            x=df1.index,
            y=df1.iloc[:, 0],
            mode="lines",
            name="Training Data",
            line=dict(color="steelblue"),
            marker=dict(color="steelblue"),
        )
    )
    # Test data: first column vs. index, in gold
    fig.add_trace(
        go.Scatter(
            x=df2.index,
            y=df2.iloc[:, 0],
            mode="lines",
            name="Test Data",
            line=dict(color="gold"),
            marker=dict(color="gold"),
        )
    )
    # Customize the layout
    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )
    return fig
def _to_datetime_index(index) -> pd.DatetimeIndex:
    """Convert a pandas index into timestamps usable on a Plotly x-axis.

    ``pd.to_datetime`` raises ``TypeError`` for ``PeriodIndex`` input, and a
    ``DatetimeIndex`` has no ``to_timestamp`` method, so each kind is routed to
    the conversion that works for it.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List) -> go.Figure:
    """
    Plot the true values together with sampled forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Its first column is
            plotted against the index, and that column's name is used for the
            title and the y-axis label. The index may be a PeriodIndex or
            anything ``pd.to_datetime`` accepts.
        forecasts (List): Forecast objects. Plain DataFrames are not enough:
            each object must expose ``samples`` (sample paths, averaged with
            ``np.mean(..., axis=0)``) and ``index`` (a pandas index aligned
            with each sample path, Period- or datetime-based). These appear to
            be GluonTS ``SampleForecast`` objects; TODO confirm against callers.

    Returns:
        go.Figure: Plotly figure object.
    """
    fig = go.Figure()

    # Ground-truth series in black.
    fig.add_trace(
        go.Scatter(
            x=_to_datetime_index(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Forecast traces; colors cycle if there are more forecasts than colors.
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # The x-axis is the same for every sample path of this forecast,
        # so convert it once instead of once per sample.
        x = _to_datetime_index(forecast.index)

        # Draw each sample path as a faint, legend-less, hover-less line. One
        # add_traces call avoids re-validating the figure for every sample.
        fig.add_traces(
            [
                go.Scatter(
                    x=x,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # keep individual samples in the background
                    name=f"Forecast {i + 1}",
                    showlegend=False,
                    hoverinfo="none",
                    line=dict(color=color),
                )
                for sample in forecast.samples
            ]
        )

        # Point-wise mean across the sample paths.
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                legendgroup="mean forecast",
                # All mean lines share one legend group; show a single entry.
                showlegend=i == 0,
            )
        )

    # Customize the layout
    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",  # Enable x-axis hover for better interactivity
    )
    return fig