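"""Utilities for visual iterative prompting (VIP).

Helpers for placing and validating candidate coordinates on an image, parsing
model responses, and computing and plotting action errors.
"""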

import re

import matplotlib.pyplot as plt
import numpy as np
import scipy.spatial.distance as distance


def min_dist(coord, coords):
  """Returns the smallest Euclidean distance from `coord` to any of `coords`."""
  if not coords:
    return np.inf
  xys = np.asarray([c.xy for c in coords])
  return np.linalg.norm(xys - np.asarray(coord.xy), axis=-1).min()


def coord_outside_image(coord, image, radius):
  """Returns True if `coord` is within `2 * radius` pixels of an image edge."""
  (height, image_width, _) = image.shape
  x, y = coord.xy
  x_outside = x > image_width - 2 * radius or x < 2 * radius
  y_outside = y > height - 2 * radius or y < 2 * radius
  return x_outside or y_outside
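

# Worked example for coord_outside_image (illustrative numbers only): for an
# image of shape (480, 640, 3) and radius=10, the margin is 2 * radius = 20
# pixels, so a coord is kept only if 20 <= x <= 620 and 20 <= y <= 460;
# (10, 200) and (630, 200) are both reported as outside.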


def is_invalid_coord(coord, coords, radius, image):
  # Invalid if within 1.5 * radius of another coord or too close to the border.
  pos_overlaps = min_dist(coord, coords) < 1.5 * radius
  return pos_overlaps or coord_outside_image(coord, image, radius)
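

# A minimal sketch (not part of the original module) of how the validity rule
# above could drive rejection sampling of well-separated coordinates. The
# names `_Coord` and `_sample_valid_coords` are hypothetical, and the checks
# are re-stated inline (1.5 * radius spacing, 2 * radius border margin) so
# the sketch runs on its own. Example call:
#   _sample_valid_coords(np.zeros((480, 640, 3)), 8, 20,
#                        np.random.default_rng(0))
import collections

import numpy as np

# Hypothetical stand-in for the coord objects used above; only `.xy` is read.
_Coord = collections.namedtuple('_Coord', ['xy'])


def _sample_valid_coords(image, num_coords, radius, rng, max_tries=1000):
  """Draws up to `num_coords` random coords that pass the validity rule."""
  height, width = image.shape[:2]
  accepted = []
  for _ in range(max_tries):
    if len(accepted) == num_coords:
      break
    # Uniform random pixel position (x along width, y along height).
    xy = (int(rng.integers(0, width)), int(rng.integers(0, height)))
    too_close = any(
        np.linalg.norm(np.subtract(xy, c.xy)) < 1.5 * radius
        for c in accepted
    )
    near_border = (
        xy[0] > width - 2 * radius or xy[0] < 2 * radius
        or xy[1] > height - 2 * radius or xy[1] < 2 * radius
    )
    # Reject the candidate if it breaks either rule; otherwise keep it.
    if not (too_close or near_border):
      accepted.append(_Coord(xy))
  return accepted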


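def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
  """Converts a polar offset from `arm_coord` into pixel coordinates (x, y).

  `angle` is in radians and `mag` in pixels; each offset is truncated to int.
  If `is_circle` is True, the point moves a further `radius` pixels along the
  same direction, with the sign taken from `mag`.
  """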
  x, y = arm_coord
  x += int(np.cos(angle) * mag)
  y += int(np.sin(angle) * mag)
  if is_circle:
    x += int(np.cos(angle) * radius * np.sign(mag))
    y += int(np.sin(angle) * radius * np.sign(mag))
  return x, y
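

# Worked example for angle_mag_2_x_y (illustrative numbers only): with
# arm_coord=(100, 100), angle=np.pi / 2 and mag=50, cos(angle) is about 6e-17
# and truncates to 0, so the result is (100, 150); with is_circle=True and
# radius=40 it becomes (100, 190). With mag=-50 and is_circle=True the point
# moves the opposite way, giving (100, 10), and with mag=0 np.sign(mag) is 0,
# so no radius offset is applied. Assuming the usual image convention where
# y grows downward, angle=np.pi / 2 points down.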


def coord_to_text_coord(coord, arm_coord, radius):
  """Returns a label position `radius` pixels past `coord`, away from the arm.

  Returns `arm_coord` unchanged when `coord` coincides with it.
  """
  delta_coord = np.asarray(coord.xy) - arm_coord
  norm = np.linalg.norm(delta_coord)
  if norm == 0:
    return arm_coord
  return (
      int(coord.xy[0] + radius * delta_coord[0] / norm),
      int(coord.xy[1] + radius * delta_coord[1] / norm),
  )
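

# Worked example for coord_to_text_coord (illustrative numbers only): with
# arm_coord=(0, 0), coord.xy=(30, 40) and radius=10, delta_coord is (30, 40)
# with norm 50, so the label position is (30 + 10 * 0.6, 40 + 10 * 0.8) =
# (36, 48), i.e. 10 pixels past the coord along the ray from the arm.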


def parse_response(response, answer_key='Arrow: ['):
  """Extracts every run of digits in a model response as an int.

  If `answer_key` appears, only the text between its last occurrence and the
  next ']' is parsed; otherwise the whole response is parsed.
  """
  if answer_key in response:
    print('parse_response from answer_key')
    arrow_response = response.split(answer_key)[-1].split(']')[0]
    return [int(val) for val in re.findall(r'\d+', arrow_response)]
  print('parse_response for all ints')
  return [int(val) for val in re.findall(r'\d+', response)]
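

# Illustrative behaviour of parse_response (the debug print is omitted):
#   parse_response('Thought: ... Arrow: [3, 7]')  ->  [3, 7]
#   parse_response('I would pick 2 or 5.')         ->  [2, 5]
#   parse_response('Arrow: [-4, 1.5]')             ->  [4, 1, 5]
# Because the pattern is r'\d+', minus signs and decimal points are not
# interpreted: each digit run is returned as a separate non-negative int.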


def compute_errors(action, true_action, verbose=False):
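  """Computes errors between a predicted action and the true action.

  Both inputs must be 1-D numpy arrays. `l2_xy_error` and `cos_xy_sim`
  compare the last two entries, and `z_error` compares the first entry.
  """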
  l2_error = np.linalg.norm(action - true_action)
  cos_sim = 1 - distance.cosine(action, true_action)
  l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
  cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
  z_error = np.abs(action[0] - true_action[0])
  errors = {
      'l2': l2_error,
      'cos_sim': cos_sim,
      'l2_xy_error': l2_xy_error,
      'cos_xy_sim': cos_xy_sim,
      'z_error': z_error,
  }

  if verbose:
    print('action: \t', [f'{a:.3f}' for a in action])
    print('true_action: \t', [f'{a:.3f}' for a in true_action])
    print(f'l2: \t\t{l2_error:.3f}')
    print(f'l2_xy_error: \t{l2_xy_error:.3f}')
    print(f'cos_sim: \t{cos_sim:.3f}')
    print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
    print(f'z_error: \t{z_error:.3f}')

  return errors
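

# Worked example (illustrative numbers, not from any experiment) of the
# metrics computed by compute_errors, re-stated standalone so it runs on its
# own. As in compute_errors, index 0 is treated as z and the last two entries
# as x/y; the helper name `_compute_errors_example` is hypothetical.
import numpy as np
import scipy.spatial.distance as distance


def _compute_errors_example():
  action = np.array([0.0, 3.0, 4.0])
  true_action = np.array([0.0, 3.0, 0.0])
  return {
      # ||(0, 0, 4)|| = 4
      'l2': np.linalg.norm(action - true_action),
      # dot product 9, norms 5 and 3, so 9 / 15 = 0.6
      'cos_sim': 1 - distance.cosine(action, true_action),
      # ||(3, 4) - (3, 0)|| = 4
      'l2_xy_error': np.linalg.norm(action[-2:] - true_action[-2:]),
      # same as cos_sim here because z is 0 in both actions
      'cos_xy_sim': 1 - distance.cosine(action[-2:], true_action[-2:]),
      # |0 - 0| = 0
      'z_error': np.abs(action[0] - true_action[0]),
  }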


def plot_errors(all_errors, error_types=None):
  """Plots the mean of each error type at every iteration, across calls.

  Each entry of `all_errors` (one per call) maps an iteration to an errors
  dict such as the one returned by `compute_errors`. At most six error types
  fit in the 2 x 3 grid.
  """
  if error_types is None:
    error_types = [
        'l2',
        'l2_xy_error',
        'z_error',
        'cos_sim',
        'cos_xy_sim',
    ]

  _, axs = plt.subplots(2, 3, figsize=(15, 8))
  for i, error_type in enumerate(error_types):  # go through each error type
    all_iter_errors = {}
    for error_by_iter in all_errors:  # go through each call
      for itr in error_by_iter:  # add error to the iteration it happened
        all_iter_errors.setdefault(itr, []).append(
            error_by_iter[itr][error_type]
        )

    # Plot iterations in order; a plain list avoids passing a dict view.
    iters = sorted(all_iter_errors)
    mean_iter_errors = [np.mean(all_iter_errors[itr]) for itr in iters]

    axs[i // 3, i % 3].plot(iters, mean_iter_errors)
    axs[i // 3, i % 3].set_title(error_type)
  for ax in axs.flat[len(error_types):]:  # hide unused subplots
    ax.axis('off')
  plt.show()
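

# A minimal sketch (hypothetical helper, not used by plot_errors) of the
# per-iteration averaging plot_errors performs before plotting, so the
# expected shape of `all_errors` is explicit: a list with one dict per call,
# mapping iteration -> errors dict. Example, for two calls where the second
# stopped after one iteration:
#   _mean_error_per_iter([{0: {'l2': 4.0}, 1: {'l2': 2.0}}, {0: {'l2': 2.0}}],
#                        'l2')
# returns ([0, 1], [3.0, 2.0]).
import numpy as np


def _mean_error_per_iter(all_errors, error_type):
  """Returns (sorted iterations, mean of `error_type` at each iteration)."""
  per_iter = {}
  for error_by_iter in all_errors:
    for itr, errors in error_by_iter.items():
      per_iter.setdefault(itr, []).append(errors[error_type])
  iters = sorted(per_iter)
  return iters, [float(np.mean(per_iter[itr])) for itr in iters]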