""" Tools Script ver: Feb 22rd 20:00 """ import os import shutil import torch import numpy as np from collections import OrderedDict # Tools def del_file(filepath): """ clear all items within a folder :param filepath: folder path :return: """ del_list = os.listdir(filepath) for f in del_list: file_path = os.path.join(filepath, f) if os.path.isfile(file_path): os.remove(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) def to_2tuple(input): if type(input) is tuple: if len(input) == 2: return input else: if len(input) > 2: output = (input[0], input[1]) return output elif len(input) == 1: output = (input[0], input[0]) return output else: print('cannot handle none tuple') else: if type(input) is list: if len(input) == 2: output = (input[0], input[1]) return output else: if len(input) > 2: output = (input[0], input[1]) return output elif len(input) == 1: output = (input[0], input[0]) return output else: print('cannot handle none list') elif type(input) is int: output = (input, input) return output else: print('cannot handle ', type(input)) raise ('cannot handle ', type(input)) def find_all_files(root, suffix=None): """ Return a list of file paths ended with specific suffix """ res = [] if type(suffix) is tuple or type(suffix) is list: for root, _, files in os.walk(root): for f in files: if suffix is not None: status = 0 for i in suffix: if not f.endswith(i): pass else: status = 1 break if status == 0: continue res.append(os.path.join(root, f)) return res elif type(suffix) is str or suffix is None: for root, _, files in os.walk(root): for f in files: if suffix is not None and not f.endswith(suffix): continue res.append(os.path.join(root, f)) return res else: print('type of suffix is not legal :', type(suffix)) return -1 # Transfer state_dict by removing misalignment def FixStateDict(state_dict, remove_key_head=None): """ Obtain a fixed state_dict by removing misalignment :param state_dict: model state_dict of OrderedDict() :param remove_key_head: the str or list of strings need to be remove by startswith """ if remove_key_head is None: return state_dict elif type(remove_key_head) == str: keys = [] for k, v in state_dict.items(): if k.startswith(remove_key_head): # 将‘arc’开头的key过滤掉,这里是要去除的层的key continue keys.append(k) elif type(remove_key_head) == list: keys = [] for k, v in state_dict.items(): jump = False for a_remove_key_head in remove_key_head: if k.startswith(a_remove_key_head): # 将‘arc’开头的key过滤掉,这里是要去除的层的key jump = True break if jump: continue else: keys.append(k) else: print('erro in defining remove_key_head !') return -1 new_state_dict = OrderedDict() for k in keys: new_state_dict[k] = state_dict[k] return new_state_dict def setup_seed(seed): # setting up the random seed import random torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True