# python test_chore_phosa.py --dataset_root /is/cluster/fast/achatterjee/CHORE_DECO/phosa_test --npz_file /is/cluster/fast/achatterjee/Datasets/hot_phosa/hot_phosa_test.npz --json_file /ps/scratch/ps_shared/stripathi/deco/4agniv/hot_phosa_split/imgnames_per_object_dict.json --th 0.05
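# Evaluates CHORE contact predictions against DECO ground-truth contact
# labels: predicted contact vertices are read off as red-colored vertices on
# the exported SMPL meshes, then scored with precision/recall/F1 plus a
# geodesic false-positive error computed on the SMPL mesh surface.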
import numpy as np
import torch
import os.path as osp
import json
import trimesh
import argparse
from tqdm import tqdm
# Precomputed (6890, 6890) geodesic distance matrix between SMPL vertices.
DIST_MATRIX = np.load('/is/cluster/fast/achatterjee/dca_contact/data/smpl/smpl_neutral_geodesic_dist.npy')
def precision_recall_f1score(gt, pred):
    """
    Compute per-sample precision, recall, and F1 for binary contact labels.
    gt:   (B, V) binary ground-truth contact labels
    pred: (B, V) contact scores, thresholded at 0.5
    """
    precision = np.zeros(gt.shape[0])
    recall = np.zeros(gt.shape[0])
    f1 = np.zeros(gt.shape[0])
    for b in range(gt.shape[0]):
        tp_num = gt[b, pred[b, :] >= 0.5].sum()
        precision_denominator = (pred[b, :] >= 0.5).sum()
        recall_denominator = (gt[b, :]).sum()
        # Check the degenerate denominators before dividing to avoid
        # divide-by-zero warnings and NaNs.
        if precision_denominator == 0:  # no predicted contact
            precision_ = 1.
            recall_ = 0.
            f1_ = 0.
        elif recall_denominator == 0:  # no GT contact
            precision_ = 0.
            recall_ = 1.
            f1_ = 0.
        else:
            precision_ = tp_num / precision_denominator
            recall_ = tp_num / recall_denominator
            if (precision_ + recall_) <= 1e-10:  # avoid dividing by ~0 in F1
                f1_ = 0.
            else:
                f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_
    return precision, recall, f1
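# A toy sanity check for the metric above (shapes follow the callers below:
# (B, V) arrays; the numbers here are illustrative only):
#   gt   = np.array([[1., 0., 1., 0.]])
#   pred = np.array([[0.9, 0.1, 0.2, 0.8]])
#   precision_recall_f1score(gt, pred)  # -> ([0.5], [0.5], [0.5])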
def det_error_metric(gt, pred):
    """
    Geodesic detection errors on the SMPL mesh.
    false_positive_dist: mean geodesic distance from each predicted contact
                         vertex to its nearest GT contact vertex
    false_negative_dist: mean geodesic distance from each GT contact vertex
                         to its nearest predicted contact vertex
    """
    dist_matrix = torch.as_tensor(DIST_MATRIX)
    false_positive_dist = torch.zeros(gt.shape[0])
    false_negative_dist = torch.zeros(gt.shape[0])
    for b in range(gt.shape[0]):
        # Keep columns for GT contact vertices and rows for predicted contact
        # vertices; fall back to the full matrix if either set is empty.
        gt_mask = torch.as_tensor(gt[b, :] == 1)
        pred_mask = torch.as_tensor(pred[b, :] >= 0.5)
        gt_columns = dist_matrix[:, gt_mask] if gt_mask.any() else dist_matrix
        error_matrix = gt_columns[pred_mask, :] if pred_mask.any() else gt_columns
        false_positive_dist[b] = error_matrix.min(dim=1)[0].mean()
        false_negative_dist[b] = error_matrix.min(dim=0)[0].mean()
    return false_positive_dist, false_negative_dist
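# Illustration (hypothetical 3-vertex mesh, not the real 6890-vertex SMPL):
# with GT contact only at vertex 0 and a prediction only at vertex 2,
# error_matrix is the single entry DIST_MATRIX[2, 0], so both the FP and FN
# distances reduce to the geodesic distance between those two vertices.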
def main(args):
    with open(args.json_file, 'r') as fp:
        img_list_dict = json.load(fp)  # object class -> list of image names
    d = np.load(args.npz_file)  # DECO GT: 'imgname' and 'contact_label' arrays
    img_count = 0
    tot_pre = 0
    tot_rec = 0
    tot_f1 = 0
    tot_fp_err = 0
    # Full-dataset evaluation loop (disabled while spot-checking single images):
    # for idx, img in tqdm(enumerate(d['imgname']), dynamic_ncols=True):
    #     gt = d['contact_label'][idx]
    #     gt = np.reshape(gt, (1, -1))
    #     imgname = osp.basename(img)
    #     for k, v in img_list_dict.items():
    #         if imgname in v:
    #             object_class = k
    #             break
    #     # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
    #     pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
    #     if not osp.exists(pred_path):
    #         # print(pred_path)
    #         continue
    #     pred_mesh = trimesh.load(pred_path, process=False)
    #     pred = np.zeros((1, 6890))
    #     for v_idx in range(pred_mesh.visual.vertex_colors.shape[0]):
    #         c = pred_mesh.visual.vertex_colors[v_idx]
    #         if np.array_equal(c, [255, 0, 0, 255]):
    #             pred[0, v_idx] = 1
    #     pre, rec, f1 = precision_recall_f1score(gt, pred)
    #     fp_err, _ = det_error_metric(gt, pred)
    #     tot_pre += pre.sum()
    #     tot_rec += rec.sum()
    #     tot_f1 += f1.sum()
    #     tot_fp_err += fp_err.numpy().sum()
    #     img_count += 1
    # print(f'Dataset size: {img_count}')
    # print(f'Threshold: {args.th}\n')
    # print(f'Test Precision: {tot_pre/img_count}')
    # print(f'Test Recall: {tot_rec/img_count}')
    # print(f'Test F1 Score: {tot_f1/img_count}')
    # print(f'Test FP Error: {tot_fp_err/img_count}')
    # Spot-check a single image.
    img_search = osp.join('/ps/project/datasets/HOT/Contact_Data/images/training', 'vcoco_000000542163.jpg')
    i = np.where(d['imgname'] == img_search)[0][0]
    print(i)
    # i = 50
    img = d['imgname'][i]
    gt = d['contact_label'][i]
    gt = np.reshape(gt, (1, -1))
    imgname = osp.basename(img)
    print(f'Image: {imgname}')
    # Look up which object class this image belongs to.
    object_class = None
    for k, v in img_list_dict.items():
        if imgname in v:
            object_class = k
            break
    print(f'Object: {object_class}')
    # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
    pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
    if not osp.exists(pred_path):
        raise FileNotFoundError(f'Missing file: {pred_path}')
    pred_mesh = trimesh.load(pred_path, process=False)
    pred = np.zeros((1, 6890))  # per-vertex contact prediction on SMPL
    for v_idx in range(pred_mesh.visual.vertex_colors.shape[0]):
        c = pred_mesh.visual.vertex_colors[v_idx]
        # Red vertices (RGBA 255, 0, 0, 255) mark predicted contact. Compare
        # element-wise: a Counter-based comparison ignores channel order and
        # would also match blue (0, 0, 255, 255).
        if np.array_equal(c, [255, 0, 0, 255]):
            pred[0, v_idx] = 1
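    # Equivalent vectorized form of the loop above (a sketch; same result,
    # avoids the per-vertex Python loop):
    #   colors = np.asarray(pred_mesh.visual.vertex_colors)
    #   pred = np.all(colors == [255, 0, 0, 255], axis=1)[None].astype(float)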
    pre, rec, f1 = precision_recall_f1score(gt, pred)
    fp_err, _ = det_error_metric(gt, pred)
    tot_pre += pre.sum()
    tot_rec += rec.sum()
    tot_f1 += f1.sum()
    tot_fp_err += fp_err.numpy().sum()
    # Metrics below are for this single image, not a dataset average.
    print(f'Test Precision: {tot_pre}')
    print(f'Test Recall: {tot_rec}')
    print(f'Test F1 Score: {tot_f1}')
    print(f'Test FP Error: {tot_fp_err}')
    # Variant of the full-dataset loop that also tracks the best-scoring image:
    # best_pre = 0
    # best_rec = 0
    # best_f1 = 0
    # best_fp_err = 0
    # best_imgname = ''
    # best_obj = ''
    # for idx, img in tqdm(enumerate(d['imgname']), dynamic_ncols=True):
    #     gt = d['contact_label'][idx]
    #     gt = np.reshape(gt, (1, -1))
    #     imgname = osp.basename(img)
    #     for k, v in img_list_dict.items():
    #         if imgname in v:
    #             object_class = k
    #             break
    #     # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
    #     pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
    #     if not osp.exists(pred_path):
    #         # print(pred_path)
    #         continue
    #     pred_mesh = trimesh.load(pred_path, process=False)
    #     pred = np.zeros((1, 6890))
    #     for v_idx in range(pred_mesh.visual.vertex_colors.shape[0]):
    #         c = pred_mesh.visual.vertex_colors[v_idx]
    #         if np.array_equal(c, [255, 0, 0, 255]):
    #             pred[0, v_idx] = 1
    #     pre, rec, f1 = precision_recall_f1score(gt, pred)
    #     fp_err, _ = det_error_metric(gt, pred)
    #     tot_pre += pre.sum()
    #     tot_rec += rec.sum()
    #     tot_f1 += f1.sum()
    #     tot_fp_err += fp_err.numpy().sum()
    #     if f1.sum() > best_f1:
    #         best_pre = pre.sum()
    #         best_rec = rec.sum()
    #         best_f1 = f1.sum()
    #         best_fp_err = fp_err.numpy().sum()
    #         best_imgname = imgname
    #         best_obj = object_class
    #     img_count += 1
    # print(f'Dataset size: {img_count}')
    # print(f'Threshold: {args.th}\n')
    # print(f'Test Precision: {tot_pre/img_count}')
    # print(f'Test Recall: {tot_rec/img_count}')
    # print(f'Test F1 Score: {tot_f1/img_count}')
    # print(f'Test FP Error: {tot_fp_err/img_count}\n')
    # print(f'Best Precision: {best_pre}')
    # print(f'Best Recall: {best_rec}')
    # print(f'Best F1 Score: {best_f1}')
    # print(f'Best FP Error: {best_fp_err}')
    # print(f'Best Image: {best_imgname} and Object: {best_obj}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', type=str, default='/is/cluster/fast/achatterjee/CHORE_DECO/hot_test')
    parser.add_argument('--npz_file', type=str, default='/is/cluster/fast/achatterjee/Datasets/hot_behave/hot_behave_test_reduced.npz')
    parser.add_argument('--json_file', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot_behave_split/imgnames_per_object_dict_reduced.json')
    parser.add_argument('--th', type=float, help='contact threshold encoded in the prediction filenames, e.g. 0.05')
    args = parser.parse_args()
    main(args)