jvedsaqib commited on
Commit
8610244
·
verified ·
1 Parent(s): bce373d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -46,13 +46,13 @@ model = LSTM(INPUT_SIZE, HIDDEN_LAYER_SIZE, OUTPUT_SIZE)
46
 
47
  model.load_state_dict(torch.load(model_path))
48
 
49
- model.eval()
50
 
51
  def predict_and_plot(week4, week3, week2, week1):
52
  last_four_weeks = [week4, week3, week2, week1]
53
 
54
  custom_input = prepare_custom_input(last_four_weeks, seq_length, scaler)
55
 
 
56
  with torch.no_grad():
57
  model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
58
  torch.zeros(1, 1, model.hidden_layer_size))
@@ -63,7 +63,7 @@ def predict_and_plot(week4, week3, week2, week1):
63
  predicted_value = f"{predicted_kilometers[0][0]:.2f}"
64
 
65
  weeks = ['Week -4', 'Week -3', 'Week -2', 'Week -1', 'Predicted Week']
66
- actual_values = [week4, week3, week2, week1, predicted_value]
67
 
68
  plt.figure(figsize=(10, 6))
69
  plt.plot(weeks, actual_values, marker='o', label='Total Kilometers')
@@ -72,6 +72,7 @@ def predict_and_plot(week4, week3, week2, week1):
72
  plt.xlabel('Weeks')
73
  plt.ylabel('Total Kilometers')
74
  plt.grid()
 
75
  plt.legend()
76
 
77
  # Saving the plot to a BytesIO object
 
46
 
47
  model.load_state_dict(torch.load(model_path))
48
 
 
49
 
50
  def predict_and_plot(week4, week3, week2, week1):
51
  last_four_weeks = [week4, week3, week2, week1]
52
 
53
  custom_input = prepare_custom_input(last_four_weeks, seq_length, scaler)
54
 
55
+ model.eval()
56
  with torch.no_grad():
57
  model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),
58
  torch.zeros(1, 1, model.hidden_layer_size))
 
63
  predicted_value = f"{predicted_kilometers[0][0]:.2f}"
64
 
65
  weeks = ['Week -4', 'Week -3', 'Week -2', 'Week -1', 'Predicted Week']
66
+ actual_values = last_four_weeks + [predicted_value]
67
 
68
  plt.figure(figsize=(10, 6))
69
  plt.plot(weeks, actual_values, marker='o', label='Total Kilometers')
 
72
  plt.xlabel('Weeks')
73
  plt.ylabel('Total Kilometers')
74
  plt.grid()
75
+ plt.tight_layout()
76
  plt.legend()
77
 
78
  # Saving the plot to a BytesIO object