Kashif Rasul commited on
Commit
46a14b8
·
1 Parent(s): 46dbe9e

added forecaster

Browse files
Files changed (2) hide show
  1. app.py +32 -11
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,19 +1,40 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
 
4
- def forecast(data):
5
- return data
 
 
6
 
 
 
 
 
 
 
 
7
 
8
- demo = gr.Interface(
9
- forecast,
10
- [
11
- gr.Timeseries(),
12
- ],
13
- [
14
- gr.Timeseries(),
15
- ],
16
- )
 
 
 
 
 
 
 
17
 
18
  if __name__ == "__main__":
19
  demo.launch()
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ from gluonts.dataset.pandas import PandasDataset
4
+ from gluonts.dataset.split import split
5
+ from gluonts.torch.model.deepar import DeepAREstimator
6
+ import matplotlib.pyplot as plt
7
 
8
 
9
+ def fn(upload_data):
10
+ df = pd.read_csv(upload_data.name, index_col=0, parse_dates=True)
11
+ dataset = PandasDataset(df, target=df.columns[0])
12
+ training_data, test_gen = split(dataset, offset=-36)
13
 
14
+ model = DeepAREstimator(
15
+ prediction_length=12,
16
+ freq=dataset.freq,
17
+ trainer_kwargs=dict(max_epochs=1),
18
+ ).train(
19
+ training_data=training_data,
20
+ )
21
 
22
+ test_data = test_gen.generate_instances(prediction_length=12, windows=3)
23
+ forecasts = list(model.predict(test_data.input))
24
+
25
+ fig = plt.figure()
26
+ df["#Passengers"].plot(color="black")
27
+ for forecast, color in zip(forecasts, ["green", "blue", "purple"]):
28
+ forecast.plot(color=f"tab:{color}")
29
+ plt.legend(["True values"], loc="upper left", fontsize="xx-large")
30
+ return fig
31
+
32
+
33
+ with gr.Blocks() as demo:
34
+ plot = gr.Plot()
35
+ upload_btn = gr.UploadButton()
36
+
37
+ upload_btn.upload(fn, inputs=upload_btn, outputs=plot)
38
 
39
  if __name__ == "__main__":
40
  demo.launch()
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  gluonts[torch,pro]
2
  pandas
3
-
 
1
  gluonts[torch,pro]
2
  pandas
3
+ matplotlib