Spaces:
Sleeping
Sleeping
import logging | |
import math | |
from typing import Dict | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import tqdm | |
from torch.utils.data import DataLoader | |
from lama.saicinpainting.evaluation.utils import move_to_device | |
LOGGER = logging.getLogger(__name__) | |
class InpaintingEvaluator():
    """Offline inpainting evaluator.

    For every configured score, iterates over the whole dataset, feeds (inpainted, image, mask) batches to the
    score object and collects the values returned by ``score.get_value``, optionally split into groups by the
    share of the image covered by the mask.
    """

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable; it receives the results dict built by ``evaluate`` and its return
            value is stored as ``dict(mean=...)`` under the key (integral_title, 'total')
        :param integral_title: name used in the result key of the value produced by integral_func
        :param clamp_image_range: optional (min, max) pair; when given, the image batch is clamped to this range
            before it is passed to the model and to the scores
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        # shuffle=False keeps the sample order stable, so group indices line up with the scores' sample order
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    @staticmethod
    def _make_interval_names(bin_edges):
        """Build human-readable names such as '10-20%' for consecutive pairs of bin edges in [0, 1]."""
        bins = len(bin_edges) - 1
        # more bins -> narrower intervals -> more decimal digits are needed to tell the names apart
        num_digits = max(0, math.ceil(math.log10(bins)) - 1)
        interval_names = []
        for idx_bin in range(bins):
            start_percent = round(100 * bin_edges[idx_bin], num_digits)
            end_percent = round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))
        return interval_names

    def _get_bin_edges(self):
        """
        Assign every sample of the dataset to a mask-area bin.

        :return: tuple (groups, interval_names), where groups is a 1-D numpy array with the bin index of every
            sample (in dataloader order) and interval_names is the list of bin names, e.g. '10-20%'
        :raises ValueError: if self.bins is smaller than 1
        """
        if self.bins < 1:
            raise ValueError(f'bins must be a positive integer, got {self.bins!r}')

        bin_edges = np.linspace(0, 1, self.bins + 1)
        interval_names = self._make_interval_names(bin_edges)

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            if not mask.is_floating_point():
                # mean() is not defined for bool / integer tensors
                mask = mask.float()
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch.
            When None, every batch must already contain the inpainting result under the key "inpainted"
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        # A plain `import tqdm` does not load the `tqdm.auto` submodule, so `tqdm.auto.tqdm` only works when
        # some other module happened to import it first. Import it explicitly here.
        from tqdm.auto import tqdm as progress_bar

        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        for score_name, score in progress_bar(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in progress_bar(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        # computed once, after all scores are available, because it may combine several of them
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
def ssim_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean SSIM and mean FID into a single F1-like value.

    :param metrics: dict that contains ('ssim', 'total') and ('fid', 'total') entries, each with a 'mean' field
    :param fid_scale: FID value that is mapped to 0; FID is rescaled as max(0, fid_scale - fid) / fid_scale
    :return: 2 * a * b / (a + b + 1e-3), where a is the mean SSIM and b is the rescaled FID
    """
    ssim_mean = metrics[('ssim', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']

    # rescale so that a smaller FID gives a larger value; anything above fid_scale is clipped to 0
    fid_quality = max(0, fid_scale - fid_mean) / fid_scale

    numerator = 2 * ssim_mean * fid_quality
    # the small constant keeps the denominator away from zero
    denominator = ssim_mean + fid_quality + 1e-3
    return numerator / denominator
def lpips_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean LPIPS and mean FID into a single F1-like value.

    :param metrics: dict that contains ('lpips', 'total') and ('fid', 'total') entries, each with a 'mean' field
    :param fid_scale: FID value that is mapped to 0; FID is rescaled as max(0, fid_scale - fid) / fid_scale
    :return: 2 * a * b / (a + b + 1e-3), where a is 1 - mean LPIPS and b is the rescaled FID
    """
    lpips_mean = metrics[('lpips', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']

    # flip LPIPS so that, like the rescaled FID below, a bigger value is better
    lpips_quality = 1 - lpips_mean
    # rescale so that a smaller FID gives a larger value; anything above fid_scale is clipped to 0
    fid_quality = max(0, fid_scale - fid_mean) / fid_scale

    numerator = 2 * lpips_quality * fid_quality
    # the small constant keeps the denominator away from zero
    denominator = lpips_quality + fid_quality + 1e-3
    return numerator / denominator
class InpaintingEvaluatorOnline(nn.Module):
    """Online inpainting evaluator.

    Accumulates scores batch by batch through ``forward`` and remembers the mask-area bin of every seen sample;
    ``evaluation_end`` then collects total and per-bin values from every score and resets the accumulated state.
    """

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: key of the reference image in the batch dict
        :param inpainted_key: key of the inpainting result in the batch dict
        :param integral_func: optional callable; it receives the results dict built by ``evaluation_end`` and its
            return value is stored as ``dict(mean=...)`` under the key (integral_title, 'total')
        :param integral_title: name used in the result key of the value produced by integral_func
        :param clamp_image_range: optional (min, max) pair; when given, the reference image batch is clamped to
            this range before it is passed to the scores
        :raises ValueError: if bins is smaller than 1
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        if bins < 1:
            raise ValueError(f'bins must be a positive integer, got {bins!r}')

        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)
        self.interval_names = self._make_interval_names(self.bin_edges)

        # bin index of every sample seen since the last evaluation_end call
        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    @staticmethod
    def _make_interval_names(bin_edges):
        """Build human-readable names such as '10-20%' for consecutive pairs of bin edges in [0, 1]."""
        bins = len(bin_edges) - 1
        # more bins -> narrower intervals -> more decimal digits are needed to tell the names apart
        num_digits = max(0, math.ceil(math.log10(bins)) - 1)
        interval_names = []
        for idx_bin in range(bins):
            start_percent = round(100 * bin_edges[idx_bin], num_digits)
            end_percent = round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))
        return interval_names

    def _get_bins(self, mask_batch):
        """
        :param mask_batch: tensor whose first dimension is the batch; the mean over all remaining elements of a
            sample is used as its mask area
        :return: numpy array with the bin index (0 .. bins_num - 1) of every sample in the batch
        """
        if not mask_batch.is_floating_point():
            # mean() is not defined for bool / integer tensors
            mask_batch = mask_batch.float()
        batch_size = mask_batch.shape[0]
        # reshape (unlike view) also accepts non-contiguous masks
        area = mask_batch.reshape(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        # side='right' puts an area that lies exactly on an edge into the bin that starts at this edge, which is
        # the convention InpaintingEvaluator uses; clip keeps area == 1 in the last bin
        bin_indices = np.searchsorted(self.bin_edges, area, side='right') - 1
        return np.clip(bin_indices, 0, self.bins_num - 1)

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image (can be overridden by self.image_key),
            inpainted (can be overridden by self.inpainted_key)
        :return: dict {score_name: value returned by the score for this batch}
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        """Alias for calling the module on batch."""
        return self(batch)

    def evaluation_end(self, states=None):
        """
        :param states: optional iterable of dicts indexed by score name; when given, the entries of every score
            are collected and passed to its get_value as states
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        # Keep the array local: if a score raises below, self.groups is still a list and forward() keeps working
        groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        # computed once, after all scores are available, because it may combine several of them
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results