asataura commited on
Commit
6256826
·
1 Parent(s): 1c2ad5f

Updating DDQN.py

Browse files
Files changed (1) hide show
  1. DDQN.py +5 -4
DDQN.py CHANGED
@@ -68,12 +68,13 @@ class DoubleDeepQNetwork:
68
  # Convert to numpy for speed by vectorization
69
  x = []
70
  y = []
71
- np_array = np.array(minibatch)
72
  st = np.zeros((0, self.nS)) # States
73
  nst = np.zeros((0, self.nS)) # Next States
74
- for i in range(len(np_array)): # Creating the state and next state np arrays
75
- st = np.append(st, np_array[i, 0], axis=0)
76
- nst = np.append(nst, np_array[i, 3], axis=0)
 
77
  st_predict = self.model.predict(st) # Here is the speedup! I can predict on the ENTIRE batch
78
  nst_predict = self.model.predict(nst)
79
  nst_predict_target = self.model_target.predict(nst) # Predict from the TARGET
 
68
  # Convert to numpy for speed by vectorization
69
  x = []
70
  y = []
71
+ np_array = list(minibatch)
72
  st = np.zeros((0, self.nS)) # States
73
  nst = np.zeros((0, self.nS)) # Next States
74
+ for i in range(len(np_array)):
75
+ st = np.append(st, np_array[i][0], axis=0)
76
+ nst = np.append(nst, np_array[i][3], axis=0)
77
+
78
  st_predict = self.model.predict(st) # Here is the speedup! I can predict on the ENTIRE batch
79
  nst_predict = self.model.predict(nst)
80
  nst_predict_target = self.model_target.predict(nst) # Predict from the TARGET