# DECO / scripts / datascripts / test_chore_phosa.py
# Repository file-viewer header: uploaded by ac5113, commit 99a05f0 ("added files"), 8.62 kB.
# python test_chore_phosa.py --dataset_root /is/cluster/fast/achatterjee/CHORE_DECO/phosa_test --npz_file /is/cluster/fast/achatterjee/Datasets/hot_phosa/hot_phosa_test.npz --json_file /ps/scratch/ps_shared/stripathi/deco/4agniv/hot_phosa_split/imgnames_per_object_dict.json --th 0.05
import numpy as np
import pandas as pd
import torch
import os
import os.path as osp
import json
import trimesh
import argparse
import collections
from tqdm import tqdm
# Pairwise vertex distance matrix, loaded once at import time from a hard-coded
# cluster path. det_error_metric() indexes its rows/columns with per-vertex masks.
# NOTE(review): judging by the file name this holds SMPL geodesic distances
# (presumably 6890 x 6890) -- confirm against the data file.
DIST_MATRIX = np.load('/is/cluster/fast/achatterjee/dca_contact/data/smpl/smpl_neutral_geodesic_dist.npy')
def precision_recall_f1score(gt, pred):
    """
    Compute per-sample precision, recall and F1 score.

    Args:
        gt: 2-D array indexed as [sample, element]; its entries are summed to
            count positives (main() passes 0/1 contact labels).
        pred: 2-D array of the same layout; an element counts as predicted
            positive when its value is >= 0.5.

    Returns:
        Tuple (precision, recall, f1) of 1-D NumPy arrays with one entry per
        sample.

    Conventions for degenerate samples:
        * no predicted positives -> precision 1, recall 0, F1 0
        * no ground-truth positives (but some predictions) -> precision 0,
          recall 1, F1 0
        * precision + recall numerically zero -> all three are 0
    """
    num_samples = gt.shape[0]
    precision = np.zeros(num_samples)
    recall = np.zeros(num_samples)
    f1 = np.zeros(num_samples)

    for b in range(num_samples):
        pred_mask = pred[b, :] >= 0.5
        tp_num = gt[b, pred_mask].sum()
        precision_denominator = pred_mask.sum()
        recall_denominator = gt[b, :].sum()

        if precision_denominator == 0:  # if no pred
            precision_, recall_, f1_ = 1., 0., 0.
        elif recall_denominator == 0:  # if no GT
            precision_, recall_, f1_ = 0., 1., 0.
        else:
            # Divide only once both denominators are known to be non-zero, so
            # degenerate samples never trigger 0/0 or x/0 (RuntimeWarning, or
            # FloatingPointError when NumPy is configured to raise).
            precision_ = tp_num / precision_denominator
            recall_ = tp_num / recall_denominator
            if (precision_ + recall_) <= 1e-10:  # to avoid precision issues
                precision_, recall_, f1_ = 0., 0., 0.
            else:
                f1_ = 2 * precision_ * recall_ / (precision_ + recall_)

        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_

    return precision, recall, f1
def det_error_metric(gt, pred, dist_matrix=None):
    """
    Compute distance-based false-positive / false-negative errors per sample.

    For every sample the distance matrix is restricted to the rows of the
    predicted vertices (pred >= 0.5) and the columns of the ground-truth
    vertices (gt == 1). The false-positive error is the mean, over the
    selected rows, of the distance to the closest selected column; the
    false-negative error is the mean, over the selected columns, of the
    distance to the closest selected row. If a sample has no ground-truth
    vertices all columns are kept, and if it has no predicted vertices all
    rows are kept.

    Args:
        gt: array or tensor indexed as [sample, vertex]; entries equal to 1
            mark ground-truth vertices.
        pred: array or tensor with the same layout; entries >= 0.5 mark
            predicted vertices.
        dist_matrix: optional square vertex-to-vertex distance matrix
            (floating point). Defaults to the module-level DIST_MATRIX.

    Returns:
        Tuple (false_positive_dist, false_negative_dist) of 1-D torch tensors
        with one entry per sample.
    """
    if dist_matrix is None:
        dist_matrix = DIST_MATRIX
    # as_tensor wraps the NumPy array without copying; torch.tensor() would
    # duplicate the whole matrix on every call.
    dist_matrix = torch.as_tensor(dist_matrix)
    gt = torch.as_tensor(gt).cpu()
    pred = torch.as_tensor(pred).cpu()

    num_samples = gt.shape[0]
    false_positive_dist = torch.zeros(num_samples)
    false_negative_dist = torch.zeros(num_samples)

    for b in range(num_samples):
        gt_mask = gt[b, :] == 1
        pred_mask = pred[b, :] >= 0.5

        gt_columns = dist_matrix[:, gt_mask] if gt_mask.any() else dist_matrix
        error_matrix = gt_columns[pred_mask, :] if pred_mask.any() else gt_columns

        # Row-wise minimum: each selected (predicted) vertex to its nearest
        # selected (ground-truth) vertex; column-wise minimum is the reverse.
        false_positive_dist[b] = error_matrix.min(dim=1)[0].mean()
        false_negative_dist[b] = error_matrix.min(dim=0)[0].mean()

    return false_positive_dist, false_negative_dist
def _find_object_class(imgname, img_list_dict):
    """Return the first key of img_list_dict whose image list contains imgname."""
    for object_class, imgnames in img_list_dict.items():
        if imgname in imgnames:
            return object_class
    raise ValueError(f'Image {imgname} is not listed under any object class')


def _load_contact_prediction(pred_path, num_vertices=6890):
    """Load a coloured mesh and return a (1, num_vertices) 0/1 contact array.

    A vertex is marked as contact only when its RGBA colour is exactly
    (255, 0, 0, 255). An order-insensitive comparison would also accept e.g.
    opaque green (0, 255, 0, 255), which has the same multiset of channels.
    """
    pred_mesh = trimesh.load(pred_path, process=False)
    colors = np.asarray(pred_mesh.visual.vertex_colors)
    is_red = np.all(colors == np.array([255, 0, 0, 255]), axis=1)
    pred = np.zeros((1, num_vertices))
    pred[0, np.flatnonzero(is_red)] = 1
    return pred


def main(args):
    """
    Evaluate the contact prediction of one hard-coded image and print metrics.

    Reads the per-object image lists from args.json_file and the ground truth
    ('imgname', 'contact_label') from args.npz_file, locates the prediction
    mesh under args.dataset_root for threshold args.th, and prints precision,
    recall, F1 and the false-positive distance error for that image.

    Raises:
        IndexError: if the hard-coded image is not present in the npz file.
        ValueError: if the image is not listed under any object class.
    """
    with open(args.json_file, 'r') as fp:
        img_list_dict = json.load(fp)

    d = np.load(args.npz_file)

    # Accumulators; img_count is only used by the disabled full-dataset loop
    # below and is kept so that loop can be re-enabled as is.
    img_count = 0
    tot_pre = 0
    tot_rec = 0
    tot_f1 = 0
    tot_fp_err = 0

    # Disabled full-dataset evaluation loop.
    # NOTE: det_error_metric takes (gt, pred); the call below uses that order.
    # for i, img in tqdm(enumerate(d['imgname']), dynamic_ncols=True):
    #     gt = d['contact_label'][i]
    #     gt = np.reshape(gt, (1, -1))
    #     imgname = osp.basename(img)
    #     for k, v in img_list_dict.items():
    #         if imgname in v:
    #             object_class = k
    #             break
    #     # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
    #     pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
    #     if not osp.exists(pred_path):
    #         # print(pred_path)
    #         continue
    #     pred_mesh = trimesh.load(pred_path, process=False)
    #     pred = np.zeros((1, 6890))
    #     for i in range(pred_mesh.visual.vertex_colors.shape[0]):
    #         c = pred_mesh.visual.vertex_colors[i]
    #         if collections.Counter(c) == collections.Counter([255, 0, 0, 255]):
    #             pred[0, i] = 1
    #     pre, rec, f1 = precision_recall_f1score(gt, pred)
    #     fp_err, _ = det_error_metric(gt, pred)
    #     tot_pre += pre.sum()
    #     tot_rec += rec.sum()
    #     tot_f1 += f1.sum()
    #     tot_fp_err += fp_err.numpy().sum()
    #     img_count += 1
    # print(f'Dataset size: {img_count}')
    # print(f'Threshold: {args.th}\n')
    # print(f'Test Precision: {tot_pre/img_count}')
    # print(f'Test Recall: {tot_rec/img_count}')
    # print(f'Test F1 Score: {tot_f1/img_count}')
    # print(f'Test FP Error: {tot_fp_err/img_count}')

    # Part of code for checking a random image
    img_search = osp.join('/ps/project/datasets/HOT/Contact_Data/images/training', 'vcoco_000000542163.jpg')
    imgnames = d['imgname']
    matches = np.where(imgnames == img_search)[0]
    if matches.size == 0:
        raise IndexError(f'Image not found in {args.npz_file}: {img_search}')
    i = matches[0]
    print(i)
    # i = 50
    img = imgnames[i]
    gt = d['contact_label'][i]
    gt = np.reshape(gt, (1, -1))
    imgname = osp.basename(img)
    print(f'Image: {imgname}')

    object_class = _find_object_class(imgname, img_list_dict)
    print(f'Object: {object_class}')

    # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
    pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
    if not osp.exists(pred_path):
        # Nothing to evaluate; stop instead of letting the mesh loader crash.
        print(f'Missing file: {pred_path}')
        return

    pred = _load_contact_prediction(pred_path)

    pre, rec, f1 = precision_recall_f1score(gt, pred)
    # Argument order follows the det_error_metric(gt, pred) signature so that
    # the first returned value really is the false-positive distance.
    fp_err, _ = det_error_metric(gt, pred)
    tot_pre += pre.sum()
    tot_rec += rec.sum()
    tot_f1 += f1.sum()
    tot_fp_err += fp_err.numpy().sum()

    print(f'Test Precision: {tot_pre}')
    print(f'Test Recall: {tot_rec}')
    print(f'Test F1 Score: {tot_f1}')
    print(f'Test FP Error: {tot_fp_err}')
# best_pre = 0
# best_rec = 0
# best_f1 = 0
# best_fp_err = 0
# best_imgname = ''
# best_obj = ''
# for i, img in tqdm(enumerate(d['imgname']), dynamic_ncols=True):
# gt = d['contact_label'][i]
# gt = np.reshape(gt, (1, -1))
# imgname = osp.basename(img)
# for k, v in img_list_dict.items():
# if imgname in v:
# object_class = k
# break
# # pred_path = osp.join(args.dataset_root, object_class, imgname, 'chore-test', f'k1.smpl_colored_{args.th}.ply')
# pred_path = osp.join(args.dataset_root, object_class, imgname, f'k1.smpl_colored_{args.th}.obj')
# if not osp.exists(pred_path):
# # print(pred_path)
# continue
# pred_mesh = trimesh.load(pred_path, process=False)
# pred = np.zeros((1, 6890))
# for i in range(pred_mesh.visual.vertex_colors.shape[0]):
# c = pred_mesh.visual.vertex_colors[i]
# if collections.Counter(c) == collections.Counter([255, 0, 0, 255]):
# pred[0, i] = 1
# pre, rec, f1 = precision_recall_f1score(gt, pred)
# fp_err, _ = det_error_metric(pred, gt)
# tot_pre += pre.sum()
# tot_rec += rec.sum()
# tot_f1 += f1.sum()
# tot_fp_err += fp_err.numpy().sum()
# if f1.sum() > best_f1:
# best_pre = pre.sum()
# best_rec = rec.sum()
# best_f1 = f1.sum()
# best_fp_err = fp_err.numpy().sum()
# best_imgname = imgname
# best_obj = object_class
# img_count += 1
# print(f'Dataset size: {img_count}')
# print(f'Threshold: {args.th}\n')
# print(f'Test Precision: {tot_pre/img_count}')
# print(f'Test Recall: {tot_rec/img_count}')
# print(f'Test F1 Score: {tot_f1/img_count}')
# print(f'Test FP Error: {tot_fp_err/img_count}\n')
# print(f'Best Precision: {best_pre}')
# print(f'Best Recall: {best_rec}')
# print(f'Best F1 Score: {best_f1}')
# print(f'Best FP Error: {best_fp_err}')
# print(f'Best Image: {best_imgname} and Object: {best_obj}')
if __name__ == '__main__':
    # Command-line entry point: parse the dataset locations and the contact
    # threshold, then run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', type=str, default='/is/cluster/fast/achatterjee/CHORE_DECO/hot_test',
                        help='root folder containing <object_class>/<imgname>/ prediction meshes')
    parser.add_argument('--npz_file', type=str, default='/is/cluster/fast/achatterjee/Datasets/hot_behave/hot_behave_test_reduced.npz',
                        help='npz file providing the imgname and contact_label arrays')
    parser.add_argument('--json_file', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot_behave_split/imgnames_per_object_dict_reduced.json',
                        help='json file mapping each object class to its image names')
    # Required: the value is interpolated into the prediction file name, so
    # omitting it would silently look for "k1.smpl_colored_None.obj".
    parser.add_argument('--th', type=float, required=True,
                        help='threshold used in the prediction file name (e.g. 0.05)')
    args = parser.parse_args()
    main(args)