File size: 4,934 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) OpenMMLab. All rights reserved.
import os
import urllib

import numpy as np
import torch
from mmengine.utils import scandir
from prettytable import PrettyTable

# from mmyolo.models import RepVGGBlock

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def switch_to_deploy(model):
    """Switch every RepVGG block inside ``model`` to its deploy form.

    ``RepVGGBlock`` is not importable from this module (its import is
    commented out), so an ``isinstance`` check against it would raise
    ``NameError``. Blocks are instead recognised by class name anywhere in
    the layer type's MRO, which also matches subclasses of ``RepVGGBlock``.

    Args:
        model: Object exposing ``modules()`` that yields its layers
            (presumably a ``torch.nn.Module`` -- confirm against callers).
    """
    for layer in model.modules():
        is_repvgg_block = any(
            cls.__name__ == 'RepVGGBlock' for cls in type(layer).__mro__)
        if is_repvgg_block:
            layer.switch_to_deploy()

    print('Switch model to deploy modality.')


def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray:
    """Tile images into a grid that is ``image_column`` images wide.

    When there are no more images than columns they are simply joined side
    by side. Otherwise the images are laid out row by row, and the last row
    is padded with white (255) tiles so that every row has the same width.

    All images are expected to share the shape of the first one, because the
    padding tiles are created with that shape and the rows are stacked with
    ``np.hstack`` / ``np.vstack``.

    Args:
        image_list (list): cv2 image list. The list is not modified.
        image_column (int): Arrange to N column. Default: 2.
    Return:
        (np.ndarray): image_column x N row merge image
    """
    img_count = len(image_list)
    if img_count <= image_column:
        # A single (possibly partial) row: no padding required.
        return np.concatenate(image_list, axis=1)

    # Ceiling division: ``round`` would round down (or to even) and silently
    # drop the images that do not fill a complete row.
    image_row = (img_count + image_column - 1) // image_column
    pad_count = image_row * image_column - img_count

    # Pad a copy so the caller's list is left untouched.
    blank = np.full(image_list[0].shape, 255, dtype=np.uint8)
    padded = list(image_list) + [blank] * pad_count

    merged_rows = [
        np.hstack(padded[row * image_column:(row + 1) * image_column])
        for row in range(image_row)
    ]

    # merge to one image
    return np.vstack(merged_rows)


def get_file_list(source_root: str) -> tuple:
    """Collect the image files designated by ``source_root``.

    The source is classified with three independent checks: whether it is an
    existing directory, whether it starts with ``http:/`` or ``https:/``, and
    whether its extension (case-insensitive) is one of ``IMG_EXTENSIONS``.
    The first matching case, in that order, decides how the file list is
    built:

    * directory: every image found recursively under it;
    * url: the file is downloaded into the current working directory and
      the local path is returned;
    * single file: the path itself (its existence is not checked).

    If none matches, a message is printed and the list is empty.

    Args:
        source_root (str): image or video source path

    Return:
        source_file_path_list (list): A list for all source file.
        source_type (dict): Source type: file or url or dir. Contains the
            boolean keys ``is_dir``, ``is_url`` and ``is_file``.
    """
    # ``import urllib`` alone does not guarantee that the ``urllib.parse``
    # submodule has been loaded, so import what is needed explicitly.
    from urllib.parse import unquote

    is_dir = os.path.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # when input source is dir; match extensions case-insensitively so
        # the result agrees with the single-file check above
        source_file_path_list = [
            os.path.join(source_root, file) for file in scandir(
                source_root,
                IMG_EXTENSIONS,
                recursive=True,
                case_sensitive=False)
        ]
    elif is_url:
        # when input source is url; drop the query string to get a file name
        filename = os.path.basename(unquote(source_root).split('?')[0])
        file_save_path = os.path.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        # when input source is single image
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)

    return source_file_path_list, source_type


def show_data_classes(data_classes):
    """Print all class names of the dataset as a table.

    Up to 25 names are shown in a single column. Longer lists are split
    into columns of 25 rows each, with the last column padded with empty
    strings so that all columns have the same length.

    Args:
        data_classes: Sized sequence of class names (list or tuple).
    """
    rows_per_column = 25

    print('\n\nThe name of the class contained in the dataset:')
    data_classes_info = PrettyTable()
    data_classes_info.title = 'Information of dataset class'
    # List Print Settings
    # If the quantity is too large, 25 rows will be displayed in each column
    if len(data_classes) <= rows_per_column:
        data_classes_info.add_column('Class name', data_classes)
    else:
        data_name_list = list(data_classes)
        # Pad up to the next multiple of 25. The padding is zero when the
        # count is already an exact multiple, a case that previously fell
        # through both branches and printed an empty table.
        pad_count = -len(data_name_list) % rows_per_column
        data_name_list.extend([''] * pad_count)
        for start in range(0, len(data_name_list), rows_per_column):
            data_classes_info.add_column(
                'Class name', data_name_list[start:start + rows_per_column])

    # Align display data to the left
    data_classes_info.align['Class name'] = 'l'
    print(data_classes_info)


def is_metainfo_lower(cfg):
    """Check that the custom dataset ``metainfo`` keys are all lowercase.

    The ``train_dataloader``, ``val_dataloader`` and ``test_dataloader``
    entries of ``cfg`` are inspected in that order. Entries that are missing
    or set to ``None`` are skipped.

    Args:
        cfg: Mapping-like config supporting ``get`` (e.g. a dict).

    Raises:
        AssertionError: If any ``metainfo`` key is not all lowercase.
    """

    def judge_keys(dataloader_cfg):
        # A dataloader may be absent or explicitly None: nothing to check.
        if not dataloader_cfg:
            return
        # Follow nested 'dataset' entries down to the innermost config.
        while 'dataset' in dataloader_cfg:
            dataloader_cfg = dataloader_cfg['dataset']
        if 'metainfo' not in dataloader_cfg:
            return
        all_keys = dataloader_cfg['metainfo'].keys()
        if all(str(k).islower() for k in all_keys):
            return
        # Raised explicitly (instead of an ``assert`` statement) so the check
        # still runs under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            f'The keys in dataset metainfo must be all lowercase, but got {all_keys}. '  # noqa
            'Please refer to https://github.com/open-mmlab/mmyolo/blob/e62c8c4593/configs/yolov5/yolov5_s-v61_syncbn_fast_1xb4-300e_balloon.py#L8'  # noqa
        )

    judge_keys(cfg.get('train_dataloader', {}))
    judge_keys(cfg.get('val_dataloader', {}))
    judge_keys(cfg.get('test_dataloader', {}))