# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
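
# Evaluation utilities for image matting: run a model over a dataset,
# restore each prediction to its original image size, accumulate matting
# metrics (e.g. SAD), and optionally save the predicted alpha mattes.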
import os
import cv2
import numpy as np
import time

import paddle
import paddle.nn.functional as F

from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
from ppmatting.metrics import metrics_class_dict

np.set_printoptions(suppress=True)


def save_alpha_pred(alpha, path):
    """
    The value of alpha should be in range [0, 255] and its shape should be [h, w].
    """
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    alpha = alpha.astype('uint8')
    cv2.imwrite(path, alpha)

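
# ``trans_info`` records the resize and padding operations applied during
# preprocessing; walking it in reverse order lets reverse_transform map a
# prediction back to the original image size.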
def reverse_transform(alpha, trans_info):
    """Recover the prediction to its original shape."""
    for item in trans_info[::-1]:
        if item[0][0] == 'resize':
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
        elif item[0][0] == 'padding':
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in trans_info".format(item[0]))
    return alpha


def evaluate(model,
             eval_dataset,
             num_workers=0,
             print_detail=True,
             save_dir='output/results',
             save_results=True,
             metrics='sad'):
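    """
    Run matting evaluation over ``eval_dataset`` and return a dict mapping
    each metric name (e.g. 'sad') to its final value. When ``save_results``
    is True, the predicted alpha mattes are also written under ``save_dir``.
    """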
    model.eval()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
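
    # Evaluation uses a batch size of 1; each sample then carries its own
    # trans_info so the prediction can be restored to its original size.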
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=1,
        drop_last=False,
        num_workers=num_workers,
        return_list=True)

    total_iters = len(loader)
    # Create metric instances and a dict that holds their latest values.
    metrics_ins = {}
    metrics_data = {}
    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    for key in metrics:
        key = key.lower()
        metrics_ins[key] = metrics_class_dict[key]()
        metrics_data[key] = None

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(
        target=total_iters, verbose=1 if nranks < 2 else 2)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()

    img_name = ''
    i = 0
    with paddle.no_grad():
        for iter, data in enumerate(loader):
            reader_cost_averager.record(time.time() - batch_start)
            alpha_pred = model(data)

            alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
            alpha_pred = alpha_pred.numpy()
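
            # Metrics and saved images work with alpha in [0, 255], so both
            # the ground truth and the prediction are scaled up from [0, 1].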
            alpha_gt = data['alpha'].numpy() * 255
            trimap = data.get('ori_trimap')
            if trimap is not None:
                trimap = trimap.numpy().astype('uint8')
            alpha_pred = np.round(alpha_pred * 255)
            for key in metrics_ins.keys():
                metrics_data[key] = metrics_ins[key].update(alpha_pred,
                                                            alpha_gt, trimap)
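
            # Optionally write the predicted alpha to disk. When a trimap is
            # available, pixels it marks as pure foreground/background are
            # forced to 255/0 before saving.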
            if save_results:
                alpha_pred_one = alpha_pred[0].squeeze()
                if trimap is not None:
                    trimap = trimap.squeeze().astype('uint8')
                    alpha_pred_one[trimap == 255] = 255
                    alpha_pred_one[trimap == 0] = 0

                save_name = data['img_name'][0]
                name, ext = os.path.splitext(save_name)
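                # A dataset may contain several samples that share an image
                # name; append an increasing index so predictions are not
                # overwritten.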
                if save_name == img_name:
                    save_name = name + '_' + str(i) + ext
                    i += 1
                else:
                    img_name = save_name
                    save_name = name + '_' + str(i) + ext
                    i = 1

                save_alpha_pred(alpha_pred_one,
                                os.path.join(save_dir, save_name))

            batch_cost_averager.record(
                time.time() - batch_start, num_samples=len(alpha_gt))
            batch_cost = batch_cost_averager.get_average()
            reader_cost = reader_cost_averager.get_average()

            if local_rank == 0 and print_detail:
                show_list = [(k, v) for k, v in metrics_data.items()]
                show_list = show_list + [('batch_cost', batch_cost),
                                         ('reader cost', reader_cost)]
                progbar_val.update(iter + 1, show_list)

            reader_cost_averager.reset()
            batch_cost_averager.reset()
            batch_start = time.time()
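
    # Compute the final value of each metric from the statistics accumulated
    # during the loop, then log the results.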
    for key in metrics_ins.keys():
        metrics_data[key] = metrics_ins[key].evaluate()

    log_str = '[EVAL] '
    for key, value in metrics_data.items():
        log_str = log_str + key + ': {:.4f}, '.format(value)
    log_str = log_str[:-2]
    logger.info(log_str)

    return metrics_data