# Copyright (c) OpenMMLab. All rights reserved.
import os
import urllib.parse
from typing import List, Tuple

import numpy as np
import torch
from mmengine.utils import scandir
from prettytable import PrettyTable

# NOTE: ``RepVGGBlock`` (from ``mmyolo.models``) is imported lazily inside
# ``switch_to_deploy`` instead of at module level.

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def switch_to_deploy(model):
    """Switch every ``RepVGGBlock`` inside ``model`` to deploy status.

    Args:
        model: An object exposing ``modules()`` (e.g. a ``torch.nn.Module``).
            ``switch_to_deploy()`` is called on every ``RepVGGBlock`` found.
    """
    # The module-level import of RepVGGBlock is disabled in this file, so the
    # name must be brought into scope here; otherwise this function always
    # fails with NameError.
    from mmyolo.models import RepVGGBlock

    for layer in model.modules():
        if isinstance(layer, RepVGGBlock):
            layer.switch_to_deploy()

    print('Switch model to deploy modality.')


def auto_arrange_images(image_list: list,
                        image_column: int = 2) -> np.ndarray:
    """Arrange images into a grid with ``image_column`` columns.

    Args:
        image_list (list): cv2 image list. All images must have the same
            shape, since they are stacked with ``np.hstack``/``np.vstack``.
            The list is not modified.
        image_column (int): Number of columns in the grid. Default: 2.

    Return:
        (np.ndarray): The merged image. At most ``image_column`` images are
        concatenated into a single row; otherwise the images are laid out
        row by row and the empty cells of the last row are filled with white
        (255) uint8 images of the same shape as the first image.
    """
    img_count = len(image_list)
    if img_count <= image_column:
        # no need to arrange
        return np.concatenate(image_list, axis=1)

    # Ceil division: ``round`` would drop images, e.g. 5 images in 2 columns
    # gives round(2.5) == 2 rows and the 5th image is lost.
    image_row = (img_count + image_column - 1) // image_column
    fill_count = image_row * image_column - img_count
    blank_img = np.ones(image_list[0].shape, dtype=np.uint8) * 255
    # Pad a copy rather than extending ``image_list`` so the caller's list
    # is left untouched.
    padded_list = list(image_list) + [blank_img] * fill_count

    merge_imgs_col = [
        np.hstack(padded_list[row * image_column:(row + 1) * image_column])
        for row in range(image_row)
    ]
    # merge to one image
    return np.vstack(merge_imgs_col)


def get_file_list(source_root: str) -> Tuple[List[str], dict]:
    """Get the list of image files described by ``source_root``.

    Args:
        source_root (str): A single image file (extension in
            ``IMG_EXTENSIONS``), a directory searched recursively for such
            images, or an ``http``/``https`` URL. A URL is downloaded into
            the current working directory first.

    Return:
        source_file_path_list (list): A list of all source file paths. It is
            empty (and a message is printed) when nothing matches.
        source_type (dict): Source type flags ``is_dir``, ``is_url`` and
            ``is_file``.
    """
    is_dir = os.path.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # when input source is dir
        for file in scandir(source_root, IMG_EXTENSIONS, recursive=True):
            source_file_path_list.append(os.path.join(source_root, file))
    elif is_url:
        # when input source is url: strip the query string to get a filename
        filename = os.path.basename(
            urllib.parse.unquote(source_root).split('?')[0])
        file_save_path = os.path.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        # when input source is single image
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)

    return source_file_path_list, source_type


def show_data_classes(data_classes, rows_per_column: int = 25):
    """Print all class names of the dataset as a table.

    Meant to be called when reporting a class-name error, so the user can
    see which class names the dataset actually contains.

    Args:
        data_classes (Sequence[str]): Class names of the dataset.
        rows_per_column (int): Maximum number of rows per table column.
            Longer lists wrap into additional columns, and the last column is
            padded with empty strings. Default: 25.
    """
    print('\n\nThe name of the class contained in the dataset:')
    data_classes_info = PrettyTable()
    data_classes_info.title = 'Information of dataset class'
    data_name_list = list(data_classes)
    class_count = len(data_name_list)

    # List Print Settings
    # The old code only handled counts that were < 25 or not a multiple of
    # 25, so e.g. exactly 25 or 50 classes produced an empty table.
    if class_count <= rows_per_column:
        data_classes_info.add_column('Class name', data_name_list)
    else:
        col_num = (class_count + rows_per_column - 1) // rows_per_column
        # Every column passed to add_column must have the same length.
        data_name_list += [''] * (col_num * rows_per_column - class_count)
        for start in range(0, len(data_name_list), rows_per_column):
            data_classes_info.add_column(
                'Class name', data_name_list[start:start + rows_per_column])

    # Align display data to the left
    data_classes_info.align['Class name'] = 'l'
    print(data_classes_info)


def is_metainfo_lower(cfg):
    """Check that the custom dataset ``metainfo`` keys are all lowercase.

    Inspects ``train_dataloader``, ``val_dataloader`` and ``test_dataloader``
    of ``cfg``. For each one, it follows nested ``dataset`` entries down to
    the innermost one and checks the keys of its ``metainfo``, if present.
    Dataloader entries that are missing, empty or ``None`` are skipped.

    Args:
        cfg: Config object supporting ``cfg.get(key, default)``.

    Raises:
        AssertionError: If any ``metainfo`` key is not lowercase.
    """

    def judge_keys(dataloader_cfg):
        # A dataloader may be explicitly set to None (e.g. no validation);
        # ``'dataset' in None`` would raise TypeError.
        if not dataloader_cfg:
            return
        while 'dataset' in dataloader_cfg:
            dataloader_cfg = dataloader_cfg['dataset']
        if 'metainfo' in dataloader_cfg:
            all_keys = dataloader_cfg['metainfo'].keys()
            all_is_lower = all(str(k).islower() for k in all_keys)
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is unchanged for callers.
            if not all_is_lower:
                raise AssertionError(
                    f'The keys in dataset metainfo must be all lowercase, but got {all_keys}. '  # noqa
                    f'Please refer to https://github.com/open-mmlab/mmyolo/blob/e62c8c4593/configs/yolov5/yolov5_s-v61_syncbn_fast_1xb4-300e_balloon.py#L8'  # noqa
                )

    judge_keys(cfg.get('train_dataloader', {}))
    judge_keys(cfg.get('val_dataloader', {}))
    judge_keys(cfg.get('test_dataloader', {}))