Spaces:
Sleeping
Sleeping
File size: 3,469 Bytes
fe3e959 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from typing import List
import numpy as np
import pandas as pd
import plotly.graph_objects as go
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets as line traces using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset, drawn in steelblue.
        df2 (pd.DataFrame): Test dataset, drawn in gold.

    Returns:
        go.Figure: Figure containing the training trace followed by the test trace.
    """
    fig = go.Figure()

    # One entry per trace: (data, legend label, colour). Order matters, because
    # traces are drawn and listed in the legend in insertion order.
    series_specs = (
        (df1, "Training Data", "steelblue"),
        (df2, "Test Data", "gold"),
    )
    for frame, label, color in series_specs:
        fig.add_trace(
            go.Scatter(
                x=frame.index,
                y=frame.iloc[:, 0],
                mode="lines",
                name=label,
                line=dict(color=color),
                marker=dict(color=color),
            )
        )

    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )
    return fig
def _to_timestamp_index(index) -> pd.DatetimeIndex:
    """Convert an index into timestamps suitable for a Plotly date axis.

    Indexes that expose ``to_timestamp`` (e.g. ``pd.PeriodIndex``) are
    converted with that method, because ``pd.to_datetime`` raises
    ``TypeError`` on period data. Every other index goes through
    ``pd.to_datetime``.
    """
    to_timestamp = getattr(index, "to_timestamp", None)
    if callable(to_timestamp):
        return to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> "go.Figure":
    """
    Plot the true values and sampled forecasts using Plotly.

    Args:
        df (pd.DataFrame): True values. The first column is plotted against
            the index, and that column's name is used for the title and
            y-axis label. The index may be a DatetimeIndex, a PeriodIndex, or
            anything ``pd.to_datetime`` accepts.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each element is accessed through ``.index`` (the
            forecast time axis) and ``.samples``. ``.samples`` is iterated one
            sample path at a time and averaged over axis 0 for the mean line.
            NOTE(review): this matches GluonTS ``SampleForecast`` objects
            rather than plain DataFrames; confirm against callers.

    Returns:
        go.Figure: Figure with the true values in black, one faint line per
        forecast sample (colours cycle green/blue/purple per forecast), and
        one dashed red mean line per forecast. Only the first mean line
        appears in the legend.
    """
    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=_to_timestamp_index(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # Convert the time axis once per forecast; every sample path and the
        # mean line share it.
        forecast_x = _to_timestamp_index(forecast.index)

        for sample in forecast.samples:
            fig.add_trace(
                go.Scatter(
                    x=forecast_x,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # Faint, so many overlapping samples show a spread.
                    name=f"Forecast {i + 1}",
                    showlegend=False,  # Keep individual sample paths out of the legend.
                    hoverinfo="none",  # Avoid hover clutter from hundreds of sample traces.
                    line=dict(color=color),
                )
            )

        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=forecast_x,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                legendgroup="mean forecast",
                # All mean traces share one legend group; list it only once.
                showlegend=i == 0,
            )
        )

    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",  # Unified hover along the x-axis.
    )
    return fig
|