Spaces:
Sleeping
Sleeping
File size: 2,422 Bytes
99a05f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
'''
Fix paths to cropped images, partmasks and segmasks
'''
import argparse
import os
import numpy as np
from tqdm import tqdm
def convert_rich_npz(orig_npz, out_dir):
# go through all keys in the npz
# if the key is imgname, partmask or segmask, replace the path with the new path
# save the new npz
# structs we use
imgnames_ = []
poses_, shapes_, transls_ = [], [], []
cams_k_ = []
contact_label_ = []
scene_seg_, part_seg_ = [], []
# load the npz
npz = np.load(orig_npz)
for i in tqdm(range(len(npz['imgname']))):
if not os.path.exists(npz['imgname'][i]):
print(npz['imgname'][i])
continue
new_scene_seg = os.path.exists(npz['scene_seg'][i].replace('seg_masks_new', 'segmentation_masks'))
if not new_scene_seg:
print(new_scene_seg)
continue
if not os.path.exists(npz['part_seg'][i]):
print(npz['part_seg'][i])
continue
imgnames_.append(npz['imgname'][i])
poses_.append(npz['pose'][i])
transls_.append(npz['transl'][i])
shapes_.append(npz['shape'][i])
cams_k_.append(npz['cam_k'][i])
contact_label_.append(npz['contact_label'][i])
scene_seg_.append(npz['scene_seg'][i].replace('seg_masks_new', 'segmentation_masks'))
part_seg_.append(npz['part_seg'][i])
# save the new npz
out_dir = out_dir+'_cropped'
os.makedirs(out_dir, exist_ok=True)
out_file = os.path.join(out_dir, os.path.basename(args.orig_npz))
np.savez(out_file,
imgname=imgnames_,
pose=poses_,
transl=transls_,
shape=shapes_,
cam_k=cams_k_,
contact_label=contact_label_,
scene_seg=scene_seg_,
part_seg=part_seg_)
print('Saved to: ', out_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--orig_npz_dir', type=str, default='/is/cluster/fast/achatterjee/rich/scene_npzs/train')
parser.add_argument('--cluster_idx', type=int)
args = parser.parse_args()
# get all npz files in the directory
npz_files = [os.path.join(args.orig_npz_dir, f) for f in os.listdir(args.orig_npz_dir) if f.endswith('.npz')]
# get the npz file for this cluster
orig_npz = npz_files[args.cluster_idx]
convert_rich_npz(orig_npz, out_dir=args.orig_npz_dir) |