"""Merge generated pseudo-mask .pt annotation files into a single COCO-style json."""
import os
import re
import json
import glob
import torch
import datetime
import argparse
import torch.nn.functional as F
import numpy as np
import pycocotools.mask as mask_util


def create_image_info(image_id, file_name, image_size,
                      date_captured=datetime.datetime.utcnow().isoformat(' '),
                      license_id=1, coco_url="", flickr_url=""):
    """Return image_info in COCO style.

    Args:
        image_id: the image ID
        file_name: the file name of each image
        image_size: image size in the format of (width, height)
        date_captured: the date this image info is created
        license_id: license of this image
        coco_url: url to COCO images if there is any
        flickr_url: url to flickr if there is any
    """
    image_info = {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url
    }
    return image_info
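
# Example (hypothetical values):
#   create_image_info(1, "example.jpg", (640, 480))
#   # -> {"id": 1, "file_name": "example.jpg", "width": 640, "height": 480, ...}
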

def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, bounding_box=None):
    """Return annotation info in COCO style.

    Args:
        annotation_id: the annotation ID
        image_id: the image ID
        category_info: the information on categories
        binary_mask: a 2D binary numpy array where '1's represent the object
        image_size: image size in the format of (width, height)
        bounding_box: the bounding box for the detection task. If bounding_box is not
            provided, one is generated from the binary mask.
    """
    # Re-binarize the mask: values above half of the maximum become foreground.
    upper = np.max(binary_mask)
    lower = np.min(binary_mask)
    thresh = upper / 2.0
    binary_mask[binary_mask > thresh] = upper
    binary_mask[binary_mask <= thresh] = lower
    if image_size is not None:
        # NOTE: resize_binary_mask is not defined in this file and is assumed to be
        # provided elsewhere; this script always passes image_size=None, so the branch
        # is never taken here.
        binary_mask = resize_binary_mask(binary_mask.astype(np.uint8), image_size)

    binary_mask_encoded = mask_util.encode(np.asfortranarray(binary_mask.astype(np.uint8)))

    area = mask_util.area(binary_mask_encoded)
    if area < 1:
        return None

    if bounding_box is None:
        bounding_box = mask_util.toBbox(binary_mask_encoded)

    # Store the segmentation as a compressed RLE with an ASCII 'counts' string so it is
    # json-serializable.
    rle = mask_util.encode(np.array(binary_mask[..., None], order="F", dtype="uint8"))[0]
    rle['counts'] = rle['counts'].decode('ascii')
    segmentation = rle

    annotation_info = {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": 0,
        "area": area.tolist(),
        "bbox": bounding_box.tolist(),
        "segmentation": segmentation,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }

    return annotation_info
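
# Example (hypothetical values):
#   m = np.zeros((2, 3), dtype=np.uint8); m[0, 1] = 1
#   ann = create_annotation_info(1, 1, {"id": 1}, m)
#   # ann["area"] == 1, ann["bbox"] == [1.0, 0.0, 1.0, 1.0],
#   # ann["segmentation"] == {"size": [2, 3], "counts": "..."}; an all-zero mask returns None.
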

INFO = {
    "description": "ImageNet-1K: pseudo-masks with MaskCut",
    "url": "https://github.com/facebookresearch/CutLER",
    "version": "1.0",
    "year": 2023,
    "contributor": "Xudong Wang",
    "date_created": datetime.datetime.utcnow().isoformat(' ')
}

LICENSES = [
    {
        "id": 1,
        "name": "Apache License",
        "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE"
    }
]

CATEGORIES = [
    {
        'id': 1,
        'name': 'fg',
        'supercategory': 'fg',
    },
]

convert = lambda text: int(text) if text.isdigit() else text.lower()
natural_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
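# Natural-sort helper (not used elsewhere in this script): digit runs compare as integers,
# e.g. natural_key('img_10.jpg') -> ['img_', 10, '.jpg'], so 'img_2' sorts before 'img_10'.
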
output = {
    "info": INFO,
    "licenses": LICENSES,
    "categories": CATEGORIES,
    "images": [],
    "annotations": []}
category_info = {
    "is_crowd": 0,
    "id": 1
}

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Merge pytorch results files into a single COCO-style json')

    parser.add_argument('--base-dir', type=str,
                        default='annotations/',
                        help='Dir to the generated annotation .pt files with CWM')
    parser.add_argument('--save-path', type=str, default="coco_train_fixsize480_N3.json",
                        help='Path to save the merged annotation file')
    args = parser.parse_args()
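
    # Example invocation (the script name is a placeholder for this file):
    #   python merge_annotations.py --base-dir annotations/ --save-path coco_train_fixsize480_N3.json
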
    file_list = glob.glob(os.path.join(args.base_dir, '*', '*'))

    ann_file = '/ccn2/u/honglinc/datasets/coco/annotations/instances_train2017.json'
    with open(ann_file, 'r') as file:
        gt_json = json.load(file)
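    # gt_json (note the hard-coded path above) is used only to look up each image's
    # original height/width by file name; no annotations are copied from it.
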
    image_id, segmentation_id = 1, 1
    image_names = []
    for file_name in file_list:
        print('processing file name', file_name)

        # Load on CPU so merging also works on machines without a GPU.
        data = torch.load(file_name, map_location='cpu')
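        # Each .pt file is assumed to map an image file name to a list of mask tensors
        # shaped like [1, 1, h, w]; this layout is inferred from how 'data' is used below.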

        for img_name, mask_list in data.items():

            # Look up the image's original size in the ground-truth json
            # (assumes every img_name appears in gt_json['images']).
            for img in gt_json['images']:
                if img['file_name'] == img_name:
                    height = img['height']
                    width = img['width']
                    break

            flag = img_name not in image_names
            if flag:
                image_info = create_image_info(
                    image_id, img_name, (width, height))
                output["images"].append(image_info)
                image_names.append(img_name)

            for mask in mask_list:

                if mask.sum() == 0:
                    continue
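                # Upsample the predicted mask to the image's original resolution with
                # bicubic interpolation, then binarize at 0.5.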
                pseudo_mask = F.interpolate(mask.float(), size=(height, width), mode='bicubic') > 0.5
                pseudo_mask = pseudo_mask[0, 0].numpy()
                annotation_info = create_annotation_info(
                    segmentation_id, image_id, category_info, pseudo_mask.astype(np.uint8), None)
                if annotation_info is not None:
                    output["annotations"].append(annotation_info)
                    segmentation_id += 1
            if flag:
                image_id += 1
            print(image_id, segmentation_id)

    with open(args.save_path, 'w') as output_json_file:
        json.dump(output, output_json_file)
    print(f'Saved {args.save_path}')
    print("Done: {} images; {} anns.".format(len(output['images']), len(output['annotations'])))