#!/usr/bin/env python3
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import string
import sys
from dataclasses import dataclass

import torch
from tqdm import tqdm

from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint, parse_model_args


@dataclass
class Result:
    dataset: str
    num_samples: int
    accuracy: float
    ned: float
    confidence: float
    label_length: float


def print_results_table(results: list[Result], file=None):
    # Column width: the widest dataset name, but at least wide enough for the header labels.
    w = max(map(len, map(getattr, results, ['dataset'] * len(results))))
    w = max(w, len('Dataset'), len('Combined'))
    print('| {:<{w}} | # samples | Accuracy | 1 - NED | Confidence | Label Length |'.format('Dataset', w=w), file=file)
    print('|:{:-<{w}}:|----------:|---------:|--------:|-----------:|-------------:|'.format('----', w=w), file=file)
    c = Result('Combined', 0, 0, 0, 0, 0)
    for res in results:
        # Accumulate sample-weighted sums; the combined means are computed after the loop.
        c.num_samples += res.num_samples
        c.accuracy += res.num_samples * res.accuracy
        c.ned += res.num_samples * res.ned
        c.confidence += res.num_samples * res.confidence
        c.label_length += res.num_samples * res.label_length
        print(
            f'| {res.dataset:<{w}} | {res.num_samples:>9} | {res.accuracy:>8.2f} | {res.ned:>7.2f} '
            f'| {res.confidence:>10.2f} | {res.label_length:>12.2f} |',
            file=file,
        )
    c.accuracy /= c.num_samples
    c.ned /= c.num_samples
    c.confidence /= c.num_samples
    c.label_length /= c.num_samples
    print('|-{:-<{w}}-|-----------|----------|---------|------------|--------------|'.format('----', w=w), file=file)
    print(
        f'| {c.dataset:<{w}} | {c.num_samples:>9} | {c.accuracy:>8.2f} | {c.ned:>7.2f} '
        f'| {c.confidence:>10.2f} | {c.label_length:>12.2f} |',
        file=file,
    )


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', help="Model checkpoint (or 'pretrained=<model_id>')")
    parser.add_argument('--data_root', default='data')
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--cased', action='store_true', default=False, help='Cased comparison')
    parser.add_argument('--punctuation', action='store_true', default=False, help='Check punctuation')
    parser.add_argument('--new', action='store_true', default=False, help='Evaluate on new benchmark datasets')
    parser.add_argument('--rotation', type=int, default=0, help='Angle of rotation (counter clockwise) in degrees.')
    parser.add_argument('--device', default='cuda')
    args, unknown = parser.parse_known_args()
    kwargs = parse_model_args(unknown)

    # The test charset is digits + lowercase by default; --cased and --punctuation widen it.
    charset_test = string.digits + string.ascii_lowercase
    if args.cased:
        charset_test += string.ascii_uppercase
    if args.punctuation:
        charset_test += string.punctuation
    kwargs.update({'charset_test': charset_test})
    print(f'Additional keyword arguments: {kwargs}')

    model = load_from_checkpoint(args.checkpoint, **kwargs).eval().to(args.device)
    hp = model.hparams
    datamodule = SceneTextDataModule(
        args.data_root,
        '_unused_',
        hp.img_size,
        hp.max_label_length,
        hp.charset_train,
        hp.charset_test,
        args.batch_size,
        args.num_workers,
        False,
        rotation=args.rotation,
    )

    test_set = SceneTextDataModule.TEST_BENCHMARK_SUB + SceneTextDataModule.TEST_BENCHMARK
    if args.new:
        test_set += SceneTextDataModule.TEST_NEW
    test_set = sorted(set(test_set))

    results = {}
    max_width = max(map(len, test_set))
    for name, dataloader in datamodule.test_dataloaders(test_set).items():
        total = 0
        correct = 0
        ned = 0
        confidence = 0
        label_length = 0
        for imgs, labels in tqdm(iter(dataloader), desc=f'{name:>{max_width}}'):
            res = model.test_step((imgs.to(model.device), labels), -1)['output']
            total += res.num_samples
            correct += res.correct
            ned += res.ned
            confidence += res.confidence
            label_length += res.label_length
        accuracy = 100 * correct / total
        mean_ned = 100 * (1 - ned / total)  # NED is accumulated as a distance; the table reports 1 - NED
        mean_conf = 100 * confidence / total
        mean_label_length = label_length / total
        results[name] = Result(name, total, accuracy, mean_ned, mean_conf, mean_label_length)

    result_groups = {
        'Benchmark (Subset)': SceneTextDataModule.TEST_BENCHMARK_SUB,
        'Benchmark': SceneTextDataModule.TEST_BENCHMARK,
    }
    if args.new:
        result_groups.update({'New': SceneTextDataModule.TEST_NEW})

    # Write the result tables both to <checkpoint>.log.txt and to stdout.
    with open(args.checkpoint + '.log.txt', 'w') as f:
        for out in [f, sys.stdout]:
            for group, subset in result_groups.items():
                print(f'{group} set:', file=out)
                print_results_table([results[s] for s in subset], out)
                print('\n', file=out)


if __name__ == '__main__':
    main()
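
# Usage sketch (the checkpoint path and the 'parseq' model id below are
# illustrative, not fixed by this script; any value accepted by
# load_from_checkpoint works, including the 'pretrained=<model_id>' form
# named in the positional argument's help text):
#
#   ./test.py pretrained=parseq --data_root data
#   ./test.py outputs/parseq/checkpoints/last.ckpt --cased --punctuation --new
#
# Each run prints the Markdown result tables to stdout and mirrors them
# to '<checkpoint>.log.txt'.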