titanhacker commited on
Commit
6a9bd56
·
verified ·
1 Parent(s): 7e38b85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.utils.data as data
7
+ import gradio as gr
8
+ import plotly.graph_objects as go
9
+
10
+ # Function to create dataset for time series prediction
11
+ def create_dataset(dataset, lookback):
12
+ X, y = [], []
13
+ for i in range(len(dataset) - lookback):
14
+ feature = dataset[i:i + lookback]
15
+ target = dataset[i + 1:i + lookback + 1]
16
+ X.append(feature)
17
+ y.append(target)
18
+ X = np.array(X).reshape(-1, lookback, 1) # Reshape to 3D (samples, lookback, features)
19
+ y = np.array(y).reshape(-1, lookback, 1) # Reshape to 3D (samples, lookback, features)
20
+ return torch.tensor(X).float(), torch.tensor(y).float()
21
+
22
+ # Define LSTM model
23
+ class AirModel(nn.Module):
24
+ def __init__(self):
25
+ super(AirModel, self).__init__()
26
+ self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=1, batch_first=True)
27
+ self.linear = nn.Linear(50, 1)
28
+
29
+ def forward(self, x):
30
+ x, _ = self.lstm(x)
31
+ x = self.linear(x)
32
+ return x
33
+
34
+ # Training and prediction function
35
+ def train_and_predict(csv_file, lookback, epochs, batch_size):
36
+ # Load CSV
37
+ df = pd.read_csv(csv_file.name)
38
+
39
+ # Extract time series data
40
+ timeseries = df[["AmtNet Sales USD"]].values.astype('float32')
41
+
42
+ # Train-test split
43
+ train_size = int(len(timeseries) * 0.67)
44
+ test_size = len(timeseries) - train_size
45
+ train, test = timeseries[:train_size], timeseries[train_size:]
46
+
47
+ # Create datasets
48
+ X_train, y_train = create_dataset(train, lookback=lookback)
49
+ X_test, y_test = create_dataset(test, lookback=lookback)
50
+
51
+ if len(X_train) == 0 or len(X_test) == 0:
52
+ return "The lookback value is too large for the dataset. Please reduce the lookback value."
53
+
54
+ # DataLoader for batching
55
+ train_loader = data.DataLoader(data.TensorDataset(X_train, y_train), shuffle=True, batch_size=batch_size)
56
+
57
+ # Initialize model, optimizer, and loss function
58
+ model = AirModel()
59
+ optimizer = optim.Adam(model.parameters())
60
+ loss_fn = nn.MSELoss()
61
+
62
+ # Training loop
63
+ for epoch in range(epochs):
64
+ model.train()
65
+ for X_batch, y_batch in train_loader:
66
+ y_pred = model(X_batch)
67
+ loss = loss_fn(y_pred, y_batch)
68
+ optimizer.zero_grad()
69
+ loss.backward()
70
+ optimizer.step()
71
+
72
+ # Prediction
73
+ model.eval()
74
+ with torch.no_grad():
75
+ train_plot = np.ones_like(timeseries) * np.nan
76
+ train_plot[lookback:train_size] = model(X_train)[:, -1, :].numpy()
77
+
78
+ test_plot = np.ones_like(timeseries) * np.nan
79
+ test_plot[train_size + lookback:len(timeseries)] = model(X_test)[:, -1, :].numpy()
80
+
81
+ # Plot results with Plotly
82
+ fig = go.Figure()
83
+ fig.add_trace(go.Scatter(y=timeseries.squeeze(), mode='lines', name='Original Data'))
84
+ fig.add_trace(go.Scatter(y=train_plot.squeeze(), mode='lines', name='Train Prediction', line=dict(color='red')))
85
+ fig.add_trace(go.Scatter(y=test_plot.squeeze(), mode='lines', name='Test Prediction', line=dict(color='green')))
86
+ fig.update_layout(title="Time Series Prediction", xaxis_title="Time", yaxis_title="Sales")
87
+
88
+ # Calculate Mean Absolute Error (MAE)
89
+ mae = np.mean(np.abs(test_plot[train_size + lookback:len(timeseries)] - timeseries[train_size + lookback:len(timeseries)]))
90
+
91
+ return fig, f"Mean Absolute Error (MAE) on Test Data: {mae:.4f}"
92
+
93
+ # Gradio app interface using new API
94
+ interface = gr.Interface(
95
+ fn=train_and_predict,
96
+ inputs=[
97
+ gr.File(label="Upload your CSV file"),
98
+ gr.Slider(10, 365, step=1, value=100, label="Lookback"),
99
+ gr.Slider(100, 5000, step=100, value=1000, label="Epochs"),
100
+ gr.Slider(4, 32, step=1, value=8, label="Batch size")
101
+ ],
102
+ outputs=[
103
+ gr.Plot(label="Prediction Plot"),
104
+ gr.Textbox(label="Error Metrics")
105
+ ],
106
+ title="Time Series Prediction with LSTM",
107
+ description="Upload a CSV file with a 'Amount Net Sales' column and get time series predictions using an LSTM model.",
108
+ )
109
+
110
+ # Launch the app with a shareable link
111
+ interface.launch()