|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for segmentation.""" |
|
|
|
import functools |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import PIL.Image |
|
|
|
from tensorflow.io import gfile |
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably read by the big_vision evaluation framework to
# decide how `predict_fn` is invoked (here: as a jitted function) -- confirm.
API = 'jit'
|
|
|
|
|
def _inrange(a, min_value, max_value):
  """Checks that clipping `a` to [min_value, max_value] leaves it unchanged.

  Args:
    a: array-like of values to check.
    min_value: lower bound (inclusive).
    max_value: upper bound (inclusive).

  Returns:
    A numpy boolean, true when every element of `a` is within the bounds
    (vacuously true for an empty array).
  """
  unchanged = np.clip(a, min_value, max_value) == a
  return unchanged.all()
|
|
|
|
|
def _area(y1, x1, y2, x2):
  """Area of the box with corners (y1, x1) and (y2, x2).

  Extents are clamped at zero, so an empty or inverted box has zero area.
  """
  width = max(x2 - x1, 0.0)
  height = max(y2 - y1, 0.0)
  return width * height
|
|
|
|
|
class Evaluator:
  """Evaluator for instance segmentation.

  For each example the model decodes 4 location tokens (a box as
  y1, x1, y2, x2, each in [0, 1023]) followed by 16 segmentation tokens (codes
  in [0, 127]). The segmentation codes are decoded into a 64x64 mask that is
  resized into the predicted box and compared against the ground-truth mask
  (mask IoU). The box itself is compared against the ground-truth box (box IoU),
  thresholded at each value in `det_ious`.
  """

  def __init__(self, predict_fn, tokenizer,
               model='oi', det_ious=(0.5, 0.75),
               *, devices, **kw):
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={'prefix', 'suffix', 'objects/mask', 'objects/bbox'},
        devices=devices, **kw)

    # Build the tokenizer once and reuse it for both decoding and the special
    # token ids below (previously it was constructed a second time).
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

    # Ids of the first location / segmentation token. Decoded ids are shifted
    # by these to recover integer box coordinates and codebook indices.
    self.loc0 = np.array(self.tok.to_int('<loc0000>'))
    self.seg0 = np.array(self.tok.to_int('<seg000>'))

    assert self.loc0.shape == (1,), self.loc0
    assert self.seg0.shape == (1,), self.seg0
    self.reconstruct_masks = get_reconstruct_masks(model)
    self.det_ious = det_ious

  def run(self, train_state):
    """Does one evaluation run, yields metrics.

    Yields (name, value) pairs:
      'miou': mean mask IoU. Invalid predictions count as IoU 0.
      'boxacc/<thr>': fraction of examples with box IoU >= thr, for each thr in
        `det_ious`.
      'invalid': number of examples whose prediction could not be turned into
        a mask.
      'total': number of examples evaluated.
    """
    ious = []
    det_by_iou = {iou: [] for iou in self.det_ious}
    invalid = total = 0
    # The mask decoder is jitted for the CPU backend, so its input is placed on
    # the first local CPU device. This is loop-invariant; look it up once.
    cpu_device = jax.local_devices(backend='cpu')[0]

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      decoded = self.decode(train_state, batch)

      # Drop padding examples (and, presumably, keep only this host's local
      # slice of the globally sharded arrays).
      not_padding = u.get_local_slice_from_fsarray(batch['_mask'])
      decoded = u.get_local_slice_from_fsarray(decoded)[not_padding]

      gt_masks = [gt[:, :, 0] > 0 for gt in batch['objects/mask'][not_padding]]
      gt_bbs = list(batch['objects/bbox'][not_padding])

      # Parse decoded ids into 4 box coordinates + 16 segmentation codes.
      # Rows of unparseable predictions stay all-zero in `tokens`.
      valid = []
      tokens = np.zeros([decoded.shape[0], 4 + 16], np.int32)
      for i, dec in enumerate(decoded):
        t = np.r_[dec[:4] - self.loc0, dec[4:4 + 16] - self.seg0]
        if (
            len(t) == 4 + 16
            and _inrange(t[:4], 0, 1023)
            and _inrange(t[4:], 0, 127)
            and t[2] > t[0] and t[3] > t[1]
        ):
          valid.append(True)
          tokens[i] = t
        else:
          valid.append(False)

      seg_indices = np.array(tokens[:, 4:])
      mask64 = jax.device_get(
          self.reconstruct_masks(jax.device_put(seg_indices, cpu_device)))
      mask64 = mask64[..., 0]
      # Box coordinates normalized to [0, 1].
      bbox = tokens[:, :4] / 1023

      for v, m64, gtm, bb, gtbb in zip(valid, mask64, gt_masks, bbox, gt_bbs):
        total += 1
        h, w = gtm.shape

        # Box IoU between predicted and ground-truth boxes (normalized coords).
        y1, x1, y2, x2 = bb
        gty1, gtx1, gty2, gtx2 = gtbb
        ibb = max(y1, gty1), max(x1, gtx1), min(y2, gty2), min(x2, gtx2)
        box_iou = _area(*ibb) / (_area(*bb) + _area(*gtbb) - _area(*ibb))
        for iou_thresh in det_by_iou:
          det_by_iou[iou_thresh].append(iou_thresh <= box_iou)

        # Mask IoU: resize the 64x64 mask to the predicted box in pixel
        # coordinates and compare with the full ground-truth mask. Predicted
        # pixels outside the box are implicitly background.
        gt_area = gtm.sum()
        y1, x1, y2, x2 = map(int, (y1 * h, x1 * w, y2 * h, x2 * w))

        if not v or x2 <= x1 or y2 <= y1:
          # Unparseable prediction, or the box rounds to zero pixels.
          iou = 0.0
          invalid += 1
        else:
          mi = np.asarray(PIL.Image.fromarray(m64).resize(
              [x2 - x1, y2 - y1], resample=PIL.Image.BILINEAR
          ))
          mi = mi > 0.0
          iarea = (gtm[y1:y2, x1:x2] & mi).sum()
          iou = iarea / (gt_area + mi.sum() - iarea)
        ious.append(iou)

    # Aggregate sums/counts (presumably across hosts) before forming ratios.
    sum_ious, num_ious, sum_dets, num_dets, num_invalid, num = c.process_sum([
        sum(ious), len(ious),
        {k: sum(v) for k, v in det_by_iou.items()},
        {k: len(v) for k, v in det_by_iou.items()},
        invalid, total
    ])

    yield 'miou', sum_ious / num_ious
    for k in sum_dets:
      yield f'boxacc/{k}', sum_dets[k] / num_dets[k]
    yield 'invalid', num_invalid
    yield 'total', num
|
|
|
|
|
# Short model names accepted by `get_reconstruct_masks`, mapped to the
# checkpoint file each one loads. Names not listed here are used as paths.
_KNOWN_MODELS = dict(
    oi='gs://big_vision/paligemma/vae-oid.npz',
)
|
|
|
|
|
def _get_params(checkpoint):
  """Converts PyTorch checkpoint to Flax params.

  Args:
    checkpoint: mapping from PyTorch-style parameter names (for example
      'decoder.0.weight') to numpy arrays.

  Returns:
    A nested params dict keyed by the submodule names `Decoder` creates
    (Conv_0, ResBlock_0, ResBlock_1, ConvTranspose_0..3, Conv_1), plus the
    codebook under '_embeddings'.
  """

  def conv(prefix):
    # Kernel axes are reordered with (2, 3, 1, 0). This presumably maps the
    # PyTorch (out, in, h, w) layout to Flax's (h, w, in, out) -- confirm.
    bias = checkpoint[prefix + '.bias']
    kernel = np.transpose(checkpoint[prefix + '.weight'], (2, 3, 1, 0))
    return {'bias': bias, 'kernel': kernel}

  def resblock(prefix):
    # The convs sit at Sequential indices 0, 2 and 4 (the gaps are presumably
    # the ReLUs) and become Conv_0, Conv_1 and Conv_2.
    return {f'Conv_{i}': conv(f'{prefix}.{j}') for i, j in enumerate((0, 2, 4))}

  params = {'_embeddings': checkpoint['_vq_vae._embedding']}
  params['Conv_0'] = conv('decoder.0')
  for i, layer in enumerate((2, 3)):
    params[f'ResBlock_{i}'] = resblock(f'decoder.{layer}.net')
  for i, layer in enumerate((4, 6, 8, 10)):
    params[f'ConvTranspose_{i}'] = conv(f'decoder.{layer}')
  params['Conv_1'] = conv('decoder.12')
  return params
|
|
|
|
|
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
  """Looks up codebook vectors and arranges them as a 4x4 spatial grid.

  Args:
    codebook_indices: integer array of shape [B, 16].
    embeddings: codebook of shape [num_embeddings, embedding_dim].

  Returns:
    Array of shape [B, 4, 4, embedding_dim] holding the looked-up vectors.
  """
  batch_size, num_tokens = codebook_indices.shape
  assert num_tokens == 16, codebook_indices.shape
  _, embedding_dim = embeddings.shape

  flat_ids = codebook_indices.reshape(-1)
  vectors = jnp.take(embeddings, flat_ids, axis=0)
  grid_shape = (batch_size, 4, 4, embedding_dim)
  return vectors.reshape(grid_shape)
|
|
|
|
|
class ResBlock(nn.Module):
  """Residual block: x + conv1x1(relu(conv3x3(relu(conv3x3(x))))).

  The input is expected to have `features` channels so the skip addition
  lines up with the conv output.
  """

  features: int

  @nn.compact
  def __call__(self, x):
    # Submodules are auto-named Conv_0, Conv_1, Conv_2 in creation order, the
    # layout `_get_params` produces for each ResBlock.
    h = nn.relu(
        nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x))
    h = nn.relu(
        nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(h))
    h = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(h)
    return h + x
|
|
|
|
|
class Decoder(nn.Module):
  """Upsamples a grid of quantized codebook vectors into a one-channel mask.

  Layers: 1x1 conv to 128 channels + relu; two 128-channel ResBlocks; four
  stride-2 transposed convs with 128, 64, 32 and 16 channels, each followed
  by relu; and a final 1x1 conv down to a single channel.
  """

  @nn.compact
  def __call__(self, x):
    base = 128
    x = nn.Conv(features=base, kernel_size=(1, 1), padding=0)(x)
    x = nn.relu(x)

    x = ResBlock(features=base)(x)
    x = ResBlock(features=base)(x)

    # The channel count halves after each upsampling stage.
    for channels in (base, base // 2, base // 4, base // 8):
      upsample = nn.ConvTranspose(
          features=channels,
          kernel_size=(4, 4),
          strides=(2, 2),
          padding=2,
          transpose_kernel=True,
      )
      x = nn.relu(upsample(x))

    return nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
|
|
|
|
|
@functools.cache
def get_reconstruct_masks(model):
  """Builds a function that reconstructs masks from codebook indices.

  Based on code from https://arxiv.org/abs/2301.02229

  Verified in
  https://colab.research.google.com/drive/1AOr0cokOpM6-N9Z5HmxoeGxGj6jS37Vl

  Results are cached per `model`, so each checkpoint is read only once.

  Args:
    model: Model to use for conversion. Either a key of `_KNOWN_MODELS` or a
      path to an .npz checkpoint readable by `gfile`.

  Returns:
    A function that expects indices shaped `[B, 16]` of dtype int32, each
    ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
    `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
  """
  checkpoint_path = _KNOWN_MODELS.get(model, model)
  with gfile.GFile(checkpoint_path, 'rb') as f:
    params = _get_params(dict(np.load(f)))

  def reconstruct_masks(codebook_indices):
    grid = _quantized_values_from_codebook_indices(
        codebook_indices, params['_embeddings'])
    return Decoder().apply({'params': params}, grid)

  return jax.jit(reconstruct_masks, backend='cpu')
|
|