import ast
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from prismatic.vla.action_tokenizer import ActionTokenizer


class Solver:
    """Parses VLA model outputs and scores predicted movements and action policies against ground truth."""

    def __init__(self, action_tokenizer=None, verbose=True) -> None:
        self.verbose = verbose
        self.action_tokenizer = action_tokenizer
        # Markers that delimit the sections expected in the generated text.
        self.coordinates_key = "NEXT GRIPPER:"
        self.movement_key = "MOVEMENT:"
        self.policy_key = "POLICIES:"

    def compare_movement(self, pred_pos, label_pos):
        # Total L1 distance between the predicted and ground-truth movement vectors,
        # that total divided element-wise by the ground truth (and summed) as a relative
        # error, and an exact-match flag.
        dist = np.sum(np.abs(pred_pos - label_pos))
        relative_dist = np.sum(np.abs(dist / label_pos))
        return dist, relative_dist, dist == 0
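    # Worked example (illustrative values only): with pred_pos = np.array([1.0, 2.0, 0.0])
    # and label_pos = np.array([1.0, 1.0, 1.0]), dist = 2.0, relative_dist = 6.0, and the
    # exact-match flag is False.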
|
|
|
    def compare_policy(self, pred_pol, label_pol):
        # Element-wise match rate over the overlapping timesteps of the two policy
        # sequences (each timestep is a 7-dim action).
        matches = 0
        cnt = 0
        for i in range(min(len(label_pol), len(pred_pol))):
            for j in range(len(label_pol[0])):
                matches += label_pol[i][j] == pred_pol[i][j]
                cnt += 1
        assert cnt % 7 == 0
        return matches / cnt if cnt else 0.0
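    # Worked example (illustrative values only): comparing a single predicted action
    # [1, 2, 3, 4, 5, 6, 7] against the ground truth [1, 2, 3, 4, 5, 0, 0] matches 5 of 7
    # components, giving an accuracy of 5 / 7.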
|
|
|
    def extract_2d_coordinates(self, text):
        # Parse the 2D gripper coordinate that follows the "NEXT GRIPPER:" marker.
        try:
            coordinates_index = text.index(self.coordinates_key) + len(self.coordinates_key)
            coord = text[coordinates_index:]
            coord = [o for o in coord.split("\n") if len(o.strip()) != 0]
            # literal_eval safely parses the bracketed coordinate without executing code.
            coord = ast.literal_eval(coord[0].strip())
        except Exception:
            # Fall back to the origin if the marker is missing or the value is malformed.
            coord = [0, 0]
        return coord
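    # Illustrative input (format assumed from the parser above):
    #   "... NEXT GRIPPER: [128, 96]" -> [128, 96]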
|
|
|
    def extract_movement_plan(self, text):
        # Parse the line following the "MOVEMENT:" marker. Two formats are handled:
        # a tokenized action string (decoded via the action tokenizer) or a textual plan
        # such as "move left 5; pitch upward 2; close gripper".
        require_unorm = None
        try:
            movement_index = text.index(self.movement_key) + len(self.movement_key)
            movement_level = text[movement_index:]
            movement_level = [o for o in movement_level.split("\n") if len(o.strip()) != 0]
            movement_level = movement_level[0].strip()

            if "gripper" not in movement_level:
                # Tokenized action string: decode token ids back to normalized actions.
                require_unorm = True
                movement_token_ids = self.action_tokenizer.tokenizer(movement_level, add_special_tokens=False).input_ids
                movement_norm = self.action_tokenizer.decode_token_ids_to_actions(np.array(movement_token_ids))
                movement_norm = movement_norm[1:8]
                assert len(movement_norm) == 7
            else:
                # Textual plan: accumulate signed, scaled levels per axis.
                require_unorm = False
                movement_level = [o for o in movement_level.split(";") if len(o) > 0]
                movement_level = movement_level[:7]

                position = defaultdict(int)
                movement_to_pos = dict(
                    move_backward=(-1, "y"),
                    move_forward=(1, "y"),
                    move_right=(-1, "x"),
                    move_left=(1, "x"),
                    move_downward=(-1, "z"),
                    move_upward=(1, "z"),
                    roll_downward=(-1, "ox"),
                    roll_upward=(1, "ox"),
                    swing_downward=(-1, "ox"),
                    swing_upward=(1, "ox"),
                    pitch_downward=(-1, "oy"),
                    pitch_upward=(1, "oy"),
                    yaw_downward=(-1, "oz"),
                    yaw_upward=(1, "oz"),
                    rotate_clockwise=(-1, "oz"),
                    rotate_counterclockwise=(1, "oz"),
                    close_gripper=(-1, "grip"),
                    open_gripper=(1, "grip"),
                )

                for ml in movement_level:
                    # The first two words name the motion (e.g. "move left"); the third is its level.
                    direction = "_".join(ml.split()[:2])
                    sign, axis = movement_to_pos[direction]
                    # Per-axis scaling: orientation axes ("ox"/"oy"/"oz") use 1e-3, the
                    # gripper is left unscaled, and translation axes use np.pi / 180.
                    scale = 1
                    if "o" in axis:
                        scale = scale * 1e-3
                    elif "grip" in axis:
                        scale = scale
                    else:
                        scale = scale / 180 * np.pi

                    if "grip" in axis:
                        level = round("open" in ml)
                    else:
                        level = int(ml.split()[2])

                    position[axis] += sign * scale * level
                movement_norm = [position[idx] for idx in ["x", "y", "z", "ox", "oy", "oz", "grip"]]

        except Exception:
            # Any parsing failure yields a sentinel movement vector.
            movement_norm = [-100] * 7

        return require_unorm, np.array(movement_norm)
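    # Illustrative textual plan (format assumed from the parser above):
    #   "MOVEMENT: move left 5; move upward 3; open gripper"
    # Each segment's first two words select an axis and sign, the trailing integer is the
    # level, and the result accumulates into the ["x", "y", "z", "ox", "oy", "oz", "grip"]
    # vector (the gripper component becomes 1 for "open gripper" and 0 for "close gripper").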
|
|
|
    def extract_action_policies(self, text):
        # Parse the tokenized action sequences that follow the "POLICIES:" marker and
        # return them along with the remaining text that precedes the marker.
        try:
            if self.policy_key in text:
                policy_index = text.index(self.policy_key) + len(self.policy_key)
                policy = text[policy_index:]
                remain_text = text[: text.index(self.policy_key)]
                policies = [o for o in policy.split("\n") if len(o.strip()) != 0]
                policies = policies[0].strip()
            else:
                policies = text.strip()
                remain_text = ""

            policies_num = []
            for policy_text in policies.split(";"):
                policy_token = self.action_tokenizer.tokenizer(policy_text, add_special_tokens=False).input_ids
                action_policy = self.action_tokenizer.decode_token_ids_to_actions(np.array(policy_token))

                # Drop the leading token and keep at most 7 action dimensions.
                action_policy = action_policy[1:]
                action_policy = action_policy[:7]

                if len(action_policy) != 7:
                    # Malformed segment: fall back to an all-zero action.
                    policies_num.append([0] * 7)
                else:
                    policies_num.append(action_policy.tolist())

        except Exception:
            policies_num = [[0] * 7]
            remain_text = text

        return policies_num, remain_text
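    # Illustrative layout (format assumed from the parser above): the text after
    # "POLICIES:" is a semicolon-separated list of tokenized 7-dim actions, each of
    # which is decoded back to a numeric action via the action tokenizer.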
|
|
|
    def evaluate_single(self, ground_truth, prediction, verbose=False):
        # Strip the policy sections first, parse the movement plans from the remaining
        # text, then score the prediction against the ground truth.
        gt_policies, ground_truth = self.extract_action_policies(ground_truth)
        pred_policies, prediction = self.extract_action_policies(prediction)

        _, pred_movement = self.extract_movement_plan(prediction)
        _, gt_movement = self.extract_movement_plan(ground_truth)

        dist, relative_dist, _ = self.compare_movement(label_pos=gt_movement, pred_pos=pred_movement)

        # Next-gripper-state scoring is currently disabled; report 0 as a placeholder.
        next_state_score = 0

        acc = self.compare_policy(label_pol=gt_policies, pred_pol=pred_policies)

        return next_state_score, acc, dist, relative_dist, pred_policies, gt_policies

    def evaluate_batch(self, batch_gt, batch_pred, verbose=False):
        # Evaluate each (ground truth, prediction) pair and collect per-sample metrics.
        state_acc_ls = []
        action_acc_ls = []
        L1_loss_ls = []
        relative_L1_loss_ls = []
        pred_policies_ls = []
        gt_policies_ls = []
        for i in range(len(batch_gt)):
            ground_truth = batch_gt[i]
            prediction = batch_pred[i]
            next_state_score, action_policy_score, L1_dist, relative_L1_dist, pred_policies, gt_policies = (
                self.evaluate_single(ground_truth, prediction)
            )
            state_acc_ls.append(next_state_score)
            action_acc_ls.append(action_policy_score)
            L1_loss_ls.append(L1_dist)
            relative_L1_loss_ls.append(relative_L1_dist)
            pred_policies_ls.append(pred_policies)
            gt_policies_ls.append(gt_policies)
            if verbose:
                print(f"Ground Truth:\n\n {ground_truth}")
                print()
                print(f"Prediction:\n\n {prediction}")
                print()
                print(f"Ground Truth Policies:\n\n {gt_policies}")
                print(f"Prediction Policies:\n\n {pred_policies}")
                print("*" * 40)

        return state_acc_ls, action_acc_ls, L1_loss_ls, relative_L1_loss_ls, pred_policies_ls, gt_policies_ls


# Module-level solver built on the base Llama-2 tokenizer wrapped by the VLA action tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", model_max_length=2048, padding_side="right")
action_tokenizer = ActionTokenizer(tokenizer)
solver = Solver(action_tokenizer)
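

# Minimal usage sketch. The ground-truth/prediction strings below are hypothetical and only
# illustrate the expected "MOVEMENT:" / "POLICIES:" layout ("placeholder" stands in for the
# tokenized action string), not real model output.
if __name__ == "__main__":
    example_gt = "MOVEMENT: move left 5; open gripper\nPOLICIES: placeholder"
    example_pred = "MOVEMENT: move left 4; open gripper\nPOLICIES: placeholder"
    state_acc, action_acc, l1, rel_l1, pred_pol, gt_pol = solver.evaluate_batch(
        [example_gt], [example_pred], verbose=False
    )
    print(action_acc, l1)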