Andrei Cozma commited on
Commit
f52cc9a
·
1 Parent(s): 8298563
Files changed (1) hide show
  1. demo.py +22 -18
demo.py CHANGED
@@ -178,40 +178,44 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
178
  1.0,
179
  )
180
 
181
- text_color = frame_policy[frame_policy_h // 2, int((action + 0.5) * frame_policy_res // len(curr_policy))]
182
- text_color = 1.0 - text_color
 
 
 
 
 
 
183
 
184
  cv2.putText(
185
  frame_policy,
186
  str(action),
187
  (
188
- int((action + 0.5) * frame_policy_res // len(curr_policy) - 8),
189
- frame_policy_h // 2 - 5,
190
  ),
191
- cv2.FONT_HERSHEY_SIMPLEX,
192
- 0.8,
193
- text_color,
194
- 1,
195
  cv2.LINE_AA,
196
  )
197
 
198
  if env_action_map:
199
  action_name = env_action_map.get(action, "")
200
-
 
201
  cv2.putText(
202
  frame_policy,
203
  action_name,
204
  (
205
- int(
206
- (action + 0.5) * frame_policy_res // len(curr_policy)
207
- - 5 * len(action_name)
208
- ),
209
- frame_policy_h // 2 + 25,
210
  ),
211
- cv2.FONT_HERSHEY_SIMPLEX,
212
- 0.5,
213
- text_color,
214
- 1,
215
  cv2.LINE_AA,
216
  )
217
 
 
178
  1.0,
179
  )
180
 
181
+ label_loc_h, label_loc_w =frame_policy_h // 2, int((action + 0.5) * frame_policy_res // len(curr_policy))
182
+
183
+ frame_policy_label_color = 1.0 - frame_policy[label_loc_h, label_loc_w]
184
+ frame_policy_label_font = cv2.FONT_HERSHEY_SIMPLEX
185
+ frame_policy_label_thicc = 1
186
+ action_text_scale, action_text_label_scale = 0.8, 0.5
187
+
188
+ (label_width, _), _ = cv2.getTextSize(str(action), frame_policy_label_font, action_text_scale, frame_policy_label_thicc)
189
 
190
  cv2.putText(
191
  frame_policy,
192
  str(action),
193
  (
194
+ label_loc_w - label_width // 2,
195
+ label_loc_h,
196
  ),
197
+ frame_policy_label_font,
198
+ action_text_scale,
199
+ frame_policy_label_color,
200
+ frame_policy_label_thicc,
201
  cv2.LINE_AA,
202
  )
203
 
204
  if env_action_map:
205
  action_name = env_action_map.get(action, "")
206
+ (label_width, _), _ = cv2.getTextSize(action_name, frame_policy_label_font, action_text_label_scale, frame_policy_label_thicc)
207
+
208
  cv2.putText(
209
  frame_policy,
210
  action_name,
211
  (
212
+ int(label_loc_w - label_width / 2),
213
+ label_loc_h + 25,
 
 
 
214
  ),
215
+ frame_policy_label_font,
216
+ action_text_label_scale,
217
+ frame_policy_label_color,
218
+ frame_policy_label_thicc,
219
  cv2.LINE_AA,
220
  )
221