"""COCO17 panoptic evaluation. |
|
|
|
jax.jit-compatible fork of the evaluator from evaluators/proj/uvim. |
|
""" |

import functools
import itertools
import json
import os
import tempfile
import time
from typing import Any
import zipfile

from absl import flags
from absl import logging
from big_vision import input_pipeline
from big_vision import utils
from big_vision.datasets import core as ds_core
import big_vision.pp.builder as pp_builder
import jax
import jax.numpy as jnp
import numpy as np
from pycocotools.panopticapi import evaluation
import panopticapi_converters.twochannels2panoptic_coco_format as converter
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.io import gfile


API = 'jit'

ROOT = os.environ.get('COCO_DATA_DIR', '.')

PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
PANOPTIC_2017 = {
    'train': f'{ROOT}/panoptic_train2017.json',
    'validation': f'{ROOT}/panoptic_val2017.json',
}

PANOPTIC_GT_ZIP = {
    'train': f'{ROOT}/panoptic_train2017.zip',
    'validation': f'{ROOT}/panoptic_val2017.zip',
}
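
# The files above are the official COCO 2017 panoptic annotation files; they
# are looked up under COCO_DATA_DIR (or the current directory by default).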


@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""
  @functools.partial(jax.jit,
                     out_shardings=jax.sharding.NamedSharding(
                         mesh, jax.sharding.PartitionSpec()))
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    y = predict_fn(train_state, batch)
    res = {
        'image/id': batch['image/id'],
        'mask': batch['_mask'],
        'y': jnp.stack([y['semantics'], y['instances']], axis=-1),
    }
    return res
  return _run_predict_fn


class Evaluator:
  """Panoptic segmentation evaluator: calls official COCO API."""

  def __init__(
      self,
      predict_fn,
      pp_fn,
      batch_size,
      data=None,
      cache_final=True,
      cache_raw=False,
      prefetch=1,
      save_dir=None,
      *,
      devices,
  ):
    """Panoptic segmentation evaluator: calls official COCO API.

    Args:
      predict_fn: jit-compilable function, which accepts arbitrary dictionaries
        of parameters and data, where the data dictionary is produced by the
        `pp_fn`. It is expected to output a 2-channel mask, where the first
        channel encodes semantics, and the second channel encodes instance ids.
      pp_fn: Preprocessing function, specified as a string.
      batch_size: Batch size.
      data: Dict specifying name and split of the dataset. Defaults to the
        standard COCO (2017).
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for
        details.
      prefetch: Number of batches to prefetch.
      save_dir: Directory to save the results in.
      devices: List of jax devices.
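
    Example (a rough sketch, not the exact training setup: `my_predict_fn`,
    the pp string, and `train_state` are placeholders whose precise form
    depends on the surrounding model and config):

      evaluator = Evaluator(
          my_predict_fn,
          pp_fn='decode|resize(512)|keep("image", "image/id")',
          batch_size=8,
          devices=jax.devices())
      for name, value in evaluator.run(train_state):
        print(name, value)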
    """
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    data_specs = dict(name='coco/2017_panoptic',
                      data_dir=None, split='validation')
    data_specs.update(data or {})
    data = ds_core.get(**data_specs)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)
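
    # Preparing the ground truth and writing prediction PNGs is only needed on
    # the host that runs the COCO API, so everything below runs on process 0.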
    if jax.process_index() == 0:
      self.result_dir = tempfile.TemporaryDirectory()
      (self.gt_folder, self.gt_json, self.categories_json,
       self.remap, self.size_map) = _prepare_ground_truth(
           data_specs['name'], data_specs['split'],
           data_specs.get('data_dir'))
      if save_dir:
        self.save_dir = save_dir.format(workdir=flags.FLAGS.workdir)
        gfile.makedirs(self.save_dir)
      else:
        self.save_dir = None

  def _compute_png_predictions(
      self, train_state: Any) -> Any:
    """Computes predictions and converts them to PNG to optimize memory use."""
    count = 0
    logging.info('Panoptic eval: running inference.')
    for batch in itertools.islice(self.data_iter, self.steps):
      out = self.predict_fn(train_state, batch)
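
      # Every process must run the jitted predict_fn, but only process 0
      # post-processes the (replicated) outputs into PNG files.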
      if jax.process_index():
        continue

      out = jax.device_get(out)
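      # `_mask` marks real examples; padding added to fill the final batch is
      # dropped here.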
      mask = out['mask']
      pan_recs = out['y'][mask]
      ids = out['image/id'][mask]

      for pan_rec, image_id in zip(pan_recs, ids):
        sem = pan_rec[..., 0]
        ins = pan_rec[..., 1]
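
        # Remap the contiguous model class indices (0 is void) back to the
        # official COCO category ids expected by the panoptic API.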
        sem_remapped = np.array(sem)
        for v in np.unique(sem):
          sem_remapped[sem == v] = self.remap[v]
        sem = sem_remapped
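
        # Pack (semantic, instance) into the R and G channels of an RGB PNG
        # and resize it back to the original image resolution.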
        pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
        pan_mask = utils.put_cpu(pan_mask)
        pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
        pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

        fname = f'{self.result_dir.name}/{image_id:012d}.png'
        with open(fname, 'wb') as f:
          f.write(pan_mask_png)
        count += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return None

    logging.info('Panoptic eval: inference done. Processed %d examples.',
                 count)
    return self.result_dir

  def run(self, train_state):
    """Run panoptic segmentation evaluation.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    result_dir = self._compute_png_predictions(train_state)

    if jax.process_index():
      return

    if self.save_dir:
      gfile.RecursivelyCopyDir(result_dir.name, self.save_dir, overwrite=True)

    with tempfile.TemporaryDirectory() as pred_folder, \
        tempfile.NamedTemporaryFile(mode='w') as pred_json:

      logging.info('Panoptic eval: running conversion.')
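      # Convert the 2-channel (semantic, instance) PNGs into the official COCO
      # panoptic format: id-encoded PNGs plus a predictions JSON file.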
      converter.converter(
          source_folder=result_dir.name,
          images_json_file=self.gt_json,
          categories_json_file=self.categories_json,
          segmentations_folder=pred_folder,
          predictions_json_file=pred_json.name)
      logging.info('Panoptic eval: conversion done.')

      logging.info('Panoptic eval: running metrics computation.')
      res = evaluation.pq_compute(gt_json_file=self.gt_json,
                                  gt_folder=self.gt_folder,
                                  pred_json_file=pred_json.name,
                                  pred_folder=pred_folder)
      logging.info('Panoptic eval: metrics computation done.')
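
      # Yield panoptic, recognition and segmentation quality for all classes
      # and for the 'Stuff'/'Things' subsets, e.g. as 'All_pq'.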
      for k in ['All', 'Stuff', 'Things']:
        for m in ['pq', 'rq', 'sq']:
          yield f'{k}_{m}', res[k][m]


def _prepare_ground_truth(dataset, split, data_dir):
  if dataset == 'coco/2017_panoptic' and data_dir is None:
    return _prepare_ground_truth_from_zipfiles(split)
  else:
    return _prepare_ground_truth_from_dataset(dataset, split, data_dir)


@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Args:
    dataset: TFDS-compatible dataset specification.
    split: Dataset split to use.
    data_dir: Folder containing the data.

  Returns:
    A tuple of (ground-truth folder, path of the ground-truth annotations
    JSON, path of the categories JSON, label-id remapping dict, dict mapping
    image id to image size).
  """
  tfds_dataset = tfds.builder(
      dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())
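
  # COCO category ids are not contiguous; the model predicts contiguous
  # indices (0 reserved for void), so map index i + 1 to the official id.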
  remap = {0: 0}
  remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(categories)}}

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in tfds_dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    ann_ids = example['panoptic_objects']['id']
    ann_labels = example['panoptic_objects']['label']
    ann_iscrowd = example['panoptic_objects']['is_crowd']
    ann_area = example['panoptic_objects']['area']

    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])
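
    # Rebuild the per-segment annotation entries; TFDS labels are 0-based, so
    # label + 1 indexes the contiguous-to-official id remap.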
    segments_info = []
    for i in range(len(ann_ids)):
      segments_info.append({
          'id': int(ann_ids[i]),
          'category_id': remap[int(ann_labels[i] + 1)],
          'iscrowd': int(ann_iscrowd[i]),
          'area': int(ann_area[i]),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map


def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from the official COCO zip files.

  Args:
    split: Dataset split to prepare ground truth for.

  Returns:
    A tuple of (ground-truth folder, path of the ground-truth annotations
    JSON, path of the categories JSON, label-id remapping dict, dict mapping
    image id to image size).
  """
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}
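
  # Restrict the ground truth to the images that are actually present in the
  # requested split (which may be a slice such as 'validation[:1000]').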
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.', len(data['annotations']))
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.', len(data['annotations']))
  filtered_gt_json = tempfile.NamedTemporaryFile(delete=False).name
  with open(filtered_gt_json, 'w') as f:
    json.dump(data, f)

  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map


@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  d = tfds.load(dataset, split=split).map(lambda x: x['image/id'])
  return frozenset(d.as_numpy_iterator())


@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  start = time.monotonic()
  local_file = tempfile.NamedTemporaryFile(delete=False)
  gfile.copy(fname, local_file.name, overwrite=True)
  logging.info('Copied %s in %d seconds.', fname, time.monotonic() - start)
  return local_file.name


@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  start = time.monotonic()
  folder = tempfile.mkdtemp()
  with tempfile.NamedTemporaryFile() as tmp_zip_file:
    gfile.copy(fname, tmp_zip_file.name, overwrite=True)
    with zipfile.ZipFile(tmp_zip_file.name, 'r') as f:
      f.extractall(folder)
  logging.info('Copied %s in %d seconds.', fname, time.monotonic() - start)
  return folder


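# `shape` is a static (hashable) argument, so each distinct target size
# compiles once; `utils.jit_cpu` runs the resize on the CPU backend.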
@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  return jax.image.resize(image, shape + image.shape[-1:], 'nearest')