Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -46,13 +46,13 @@ model = LSTM(INPUT_SIZE, HIDDEN_LAYER_SIZE, OUTPUT_SIZE)
|
|
46 |
|
47 |
model.load_state_dict(torch.load(model_path))
|
48 |
|
49 |
-
model.eval()
|
50 |
|
51 |
def predict_and_plot(week4, week3, week2, week1):
|
52 |
last_four_weeks = [week4, week3, week2, week1]
|
53 |
|
54 |
custom_input = prepare_custom_input(last_four_weeks, seq_length, scaler)
|
55 |
|
|
|
56 |
with torch.no_grad():
|
57 |
model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
|
58 |
torch.zeros(1, 1, model.hidden_layer_size))
|
@@ -63,7 +63,7 @@ def predict_and_plot(week4, week3, week2, week1):
|
|
63 |
predicted_value = f"{predicted_kilometers[0][0]:.2f}"
|
64 |
|
65 |
weeks = ['Week -4', 'Week -3', 'Week -2', 'Week -1', 'Predicted Week']
|
66 |
-
actual_values =
|
67 |
|
68 |
plt.figure(figsize=(10, 6))
|
69 |
plt.plot(weeks, actual_values, marker='o', label='Total Kilometers')
|
@@ -72,6 +72,7 @@ def predict_and_plot(week4, week3, week2, week1):
|
|
72 |
plt.xlabel('Weeks')
|
73 |
plt.ylabel('Total Kilometers')
|
74 |
plt.grid()
|
|
|
75 |
plt.legend()
|
76 |
|
77 |
# Saving the plot to a BytesIO object
|
|
|
46 |
|
47 |
model.load_state_dict(torch.load(model_path))
|
48 |
|
|
|
49 |
|
50 |
def predict_and_plot(week4, week3, week2, week1):
|
51 |
last_four_weeks = [week4, week3, week2, week1]
|
52 |
|
53 |
custom_input = prepare_custom_input(last_four_weeks, seq_length, scaler)
|
54 |
|
55 |
+
model.eval()
|
56 |
with torch.no_grad():
|
57 |
model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
|
58 |
torch.zeros(1, 1, model.hidden_layer_size))
|
|
|
63 |
predicted_value = f"{predicted_kilometers[0][0]:.2f}"
|
64 |
|
65 |
weeks = ['Week -4', 'Week -3', 'Week -2', 'Week -1', 'Predicted Week']
|
66 |
+
actual_values = last_four_weeks + [predicted_value]
|
67 |
|
68 |
plt.figure(figsize=(10, 6))
|
69 |
plt.plot(weeks, actual_values, marker='o', label='Total Kilometers')
|
|
|
72 |
plt.xlabel('Weeks')
|
73 |
plt.ylabel('Total Kilometers')
|
74 |
plt.grid()
|
75 |
+
plt.tight_layout()
|
76 |
plt.legend()
|
77 |
|
78 |
# Saving the plot to a BytesIO object
|