# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import shutil
import ssl
import urllib.request as request
from typing import Dict, List, Optional, Tuple

from mmengine import mkdir_or_exist

from mmocr.registry import DATA_OBTAINERS
from mmocr.utils import check_integrity, is_archive

# NOTE(review): this replaces the default HTTPS context factory for the whole
# process, so certificate verification is disabled for every urllib-based
# HTTPS request made after this module is imported, not only for the
# downloads below. Kept as-is because callers may rely on it.
ssl._create_default_https_context = ssl._create_unverified_context


@DATA_OBTAINERS.register_module()
class NaiveDataObtainer:
    """A naive pipeline for obtaining dataset.

    download -> extract -> move

    Args:
        files (list[dict]): A list of file information. Each dict may carry
            the keys ``url``, ``md5``, ``save_name`` and ``mapping``, which
            are read in :meth:`__call__`.
        cache_path (str): The path to cache the downloaded files.
        data_root (str): The root path of the dataset. It is usually set auto-
            matically and users do not need to set it manually in config file
            in most cases.
        task (str): The task of the dataset. It is usually set automatically
            and users do not need to set it manually in config file
            in most cases.
    """

    def __init__(self, files: List[Dict], cache_path: str, data_root: str,
                 task: str) -> None:
        self.files = files
        self.cache_path = cache_path
        self.data_root = data_root
        self.task = task
        # Create the directory layout up front so later steps can assume the
        # target folders exist.
        mkdir_or_exist(self.data_root)
        mkdir_or_exist(osp.join(self.data_root, f'{task}_imgs'))
        mkdir_or_exist(osp.join(self.data_root, 'annotations'))
        mkdir_or_exist(self.cache_path)

    def __call__(self) -> None:
        """Run download -> extract -> move for every entry in ``files``,
        then remove the empty directories left under ``data_root``."""
        for file in self.files:
            save_name = file.get('save_name', None)
            url = file.get('url', None)
            md5 = file.get('md5', None)
            # The cached file is named after ``save_name`` when given,
            # otherwise after the last component of the url.
            download_path = osp.join(
                self.cache_path,
                osp.basename(url) if save_name is None else save_name)
            # Download required files
            if not check_integrity(download_path, md5):
                self.download(url=url, dst_path=download_path)
            # Extract downloaded zip files to data root
            self.extract(src_path=download_path, dst_path=self.data_root)
            # Move & Rename dataset files
            if 'mapping' in file:
                self.move(mapping=file['mapping'])
        self.clean()

    def download(self, url: Optional[str], dst_path: str) -> None:
        """Download file from given url with progress bar.

        Args:
            url (str, optional): The url to download the file.
            dst_path (str): The destination path to save the file.

        Raises:
            FileNotFoundError: If ``url`` is None, i.e. the file has to be
                fetched manually.
            NotImplementedError: If ``url`` is a magnet link.
        """

        def progress(down: float, block: float, size: float) -> None:
            """Show download progress.

            Args:
                down (float): Downloaded size.
                block (float): Block size.
                size (float): Total size of the file.
            """
            if size > 0:
                percent = min(100. * down * block / size, 100)
                status = f'{percent:.2f}%'
            else:
                # The total size may be reported as unknown (-1) or zero; a
                # percentage would be meaningless or divide by zero, so show
                # the amount received instead.
                status = f'{int(down * block)} bytes'
            file_name = osp.basename(dst_path)
            print(f'\rDownloading {file_name}: {status}', end='')

        # Without a url nothing can be fetched, whether or not a (possibly
        # corrupted) file already sits at ``dst_path``. Checking only
        # ``url is None`` also keeps the ``startswith`` call below safe.
        if url is None:
            raise FileNotFoundError(
                'Direct url is not available for this dataset.'
                ' Please manually download the required files'
                ' following the guides.')

        if url.startswith('magnet'):
            raise NotImplementedError('Please use any BitTorrent client to '
                                      'download the following magnet link to '
                                      f'{osp.abspath(dst_path)} and '
                                      f'try again.\nLink: {url}')

        print('Downloading...')
        print(f'URL: {url}')
        print(f'Destination: {osp.abspath(dst_path)}')
        print('If you stuck here for a long time, please check your network, '
              'or manually download the file to the destination path and '
              'run the script again.')
        request.urlretrieve(url, dst_path, progress)
        print('')

    def extract(self,
                src_path: str,
                dst_path: str,
                delete: bool = False) -> None:
        """Extract zip/tar.gz files.

        Files that ``is_archive`` does not recognise are copied to
        ``dst_path`` instead. Archives are extracted into a sub-folder of
        ``dst_path`` named after the archive; a ``.finish`` marker file in
        that folder records a completed extraction so it can be skipped on
        the next run.

        Args:
            src_path (str): Path to the zip file.
            dst_path (str): Path to the destination folder.
            delete (bool, optional): Whether to delete the zip file. Defaults
                to False.
        """
        if not is_archive(src_path):
            # Copy the file to the destination folder if it is not a zip
            if osp.isfile(src_path):
                shutil.copy(src_path, dst_path)
            else:
                shutil.copytree(src_path, dst_path)
            return

        # Everything before the first dot, e.g. "imgs" for "imgs.tar.gz".
        zip_name = osp.basename(src_path).split('.')[0]
        if dst_path is None:
            dst_path = osp.join(osp.dirname(src_path), zip_name)
        else:
            dst_path = osp.join(dst_path, zip_name)

        finish_flag = osp.join(dst_path, '.finish')
        extracted = False
        if osp.exists(dst_path):
            name = set(os.listdir(dst_path))
            if '.finish' in name:
                extracted = True
            elif len(name) > 0:
                # The folder has content but no marker: ask the user whether
                # to trust it or to extract again. An empty answer means "N".
                while True:
                    c = input(f'{dst_path} already exists when extracting '
                              f'{zip_name}, unzip again? (y/N) ') or 'N'
                    if c.lower() in ['y', 'n']:
                        # Compare case-insensitively so that "N" (including
                        # the default) really skips the extraction.
                        extracted = c.lower() == 'n'
                        break
        if extracted:
            with open(finish_flag, 'w'):
                pass
            print(f'{zip_name} has been extracted. Skip')
            return
        mkdir_or_exist(dst_path)
        print(f'Extracting: {osp.basename(src_path)}')
        if src_path.endswith('.zip'):
            import zipfile
            with zipfile.ZipFile(src_path, 'r') as zip_ref:
                zip_ref.extractall(dst_path)
        elif src_path.endswith(('.tar.gz', '.tar')):
            import tarfile
            mode = 'r:gz' if src_path.endswith('.tar.gz') else 'r:'
            # NOTE(review): extractall() trusts the member paths stored in
            # the archive; only use archives from trusted sources.
            with tarfile.open(src_path, mode) as tar_ref:
                tar_ref.extractall(dst_path)

        with open(finish_flag, 'w'):
            pass
        if delete:
            os.remove(src_path)

    def move(self, mapping: List[Tuple[str, str]]) -> None:
        """Rename and move dataset files one by one.

        Both names of each pair are resolved relative to ``data_root``. A
        source containing ``*`` is expanded with ``glob`` and every match is
        moved into the destination directory; otherwise the source is moved
        (renamed) to the destination. Existing targets are never overwritten.

        Args:
            mapping (List[Tuple[str, str]]): A list of tuples, each
                tuple contains the source file name and the destination file
                name.
        """
        for src, dst in mapping:
            src = osp.join(self.data_root, src)
            dst = osp.join(self.data_root, dst)
            if '*' in src:
                mkdir_or_exist(dst)
                for f in glob.glob(src):
                    # shutil.move() places ``f`` at dst/<basename of f>, so
                    # that is the path whose existence must be checked.
                    if not osp.exists(osp.join(dst, osp.basename(f))):
                        shutil.move(f, dst)

            elif osp.exists(src) and not osp.exists(dst):
                mkdir_or_exist(osp.dirname(dst))
                shutil.move(src, dst)

    def clean(self) -> None:
        """Remove empty dirs."""
        for root, _, _ in os.walk(self.data_root, topdown=False):
            # In bottom-up mode the sub-directory list yielded by os.walk()
            # was collected before the children were visited, so it may name
            # directories that were just removed. Check the disk instead so
            # nested empty directories are removed all the way up.
            if not os.listdir(root):
                os.rmdir(root)