import dataclasses
import inspect
import operator
import os
import re
import shutil
from collections import Counter
from functools import partial, wraps
from typing import Dict, Optional, Tuple

import numpy as np
import requests
from funcy import rcompose, identity
from loguru import logger
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.losses import Loss


def dir_decorator(fn):
    """Create the parent folder of ``path`` (if any) before calling ``fn``."""

    @wraps(fn)
    def deco(path, *args, **kwargs):
        if path is not None:
            folder = os.path.dirname(path)
            if folder:
                os.makedirs(folder, exist_ok=True)
        return fn(path, *args, **kwargs)

    return deco


def extract_args(args):
    """Normalize ``args`` to a tuple, unpacking a single nested list/tuple."""
    if il(args):
        if il(args[0]):
            return tuple(args[0])
        return tuple(args)
    return tuple(lw(args))


def comp(*args):
    """Compose functions left to right (the first function is applied first)."""
    return rcompose(*extract_args(args))


def is_model(obj):
    return isinstance(obj, (Model, Layer, Loss))


def list_dir(root):
    """Yield the full path of every entry in ``root``."""
    for path in os.listdir(root):
        yield os.path.join(root, path)


def is_int(obj, bool_=False):
    """Return True for Python and NumPy integers; bools count only if ``bool_``."""
    if not bool_ and isinstance(obj, bool):
        return False
    return isinstance(obj, (int, np.integer))
    # return issubclass(type(obj),int)


def is_float(obj):
    # np.float was removed in NumPy 1.24; np.floating covers NumPy float scalars.
    return isinstance(obj, (float, np.floating))


def is_num(obj, bool_=False):
    return is_int(obj, bool_) or is_float(obj)


def get_args(obj):
    """Return the parameter names of ``obj``, or ["arg"] if they cannot be read."""
    try:
        return list(inspect.signature(obj).parameters.keys())
    except Exception as e:
        logger.error(f"{e}. Returning a single placeholder argument.")
        return ["arg"]


def il(data):
    """Return True if ``data`` is a list or tuple."""
    return isinstance(data, (list, tuple))


def lw(data, none_empty=True, convert_tuple=True):
    """Wrap ``data`` in a list unless it already is one."""
    if isinstance(data, list):
        return data
    elif isinstance(data, tuple) and convert_tuple:
        return list(data)
    if none_empty and data is None:
        return []
    return [data]


def dict_from_list(keys=None, values=None) -> Dict:
    """Build a dict from keys and values; either may be None or a callable."""
    if values is None:
        values = list(range(len(lw(keys))))
    elif callable(values):
        values = [values(k) for k in keys]
    if keys is None:
        keys = list(range(len(lw(values))))
    elif callable(keys):
        keys = [keys(v) for v in values]
    if not isinstance(values, dict):
        values = dict(zip(keys[:len(lw(values))], lw(values)))
    return values


def call(fn, x=None):
    """Call ``fn`` with ``x`` as keyword, positional, or single argument."""
    if len(get_args(fn)) > 0:
        if isinstance(x, dict):
            return fn(**x)
        elif il(x):
            return fn(*x)
        else:
            return fn(x)
    else:
        return fn()


def get_str_comp(str_comp):
    """Resolve a comparison name ("in", "equal", "=", "re") to a function."""
    if callable(str_comp):
        return str_comp
    elif isinstance(str_comp, str):
        if str_comp == "in":
            return lambda x, y: x in y
        elif str_comp in ["equal", "="]:
            return lambda x, y: x == y
        elif str_comp == "re":
            return lambda x, y: re.match(x, y)
    # Fall back to equality for unknown names and unsupported types.
    return lambda x, y: x == y


def test(filters, x=None, str_comp="in"):
    """Apply every filter to ``x`` and return the list of results."""
    str_comp = get_str_comp(str_comp)
    result = []
    for f in lw(filters):
        if isinstance(f, str):
            result.append(str_comp(f, x))
        elif callable(f):
            result.append(call(f, x))
        else:
            result.append(f == x)
    return result


test_all = comp([test, all])
test_any = comp([test, any])


def filter_keys(d, keys, reverse=False, str_comp="=", *args, **kwargs):
    """Keep the items of ``d`` whose key matches ``keys`` (or not, if ``reverse``)."""
    fn = test_any
    if dataclasses.is_dataclass(keys):
        keys = dataclasses.asdict(keys)
    if isinstance(keys, dict):
        keys = list(keys.keys())
    if dataclasses.is_dataclass(d):
        d = dataclasses.asdict(d)
    if reverse:
        fn = comp(test_any, lambda x: not x)
    return {k: v for k, v in d.items() if fn(keys, k, *args, str_comp=str_comp, **kwargs)}


def none(x):
    return [] if x is None else x


def if_images(data):
    """Heuristic: does the shape of ``data`` look like image data?"""
    if not hasattr(data, "shape"):
        return False
    shape = data.shape
    return len(shape) > 3 or (len(shape) == 3 and shape[-1] > 4)


def download_file(url, path=None):
    """Stream ``url`` to ``path`` (defaults to the last URL segment)."""
    if path is None:
        path = url.split("/")[-1]
    r = requests.get(url, stream=True)
    if r.status_code == 200:
        r.raw.decode_content = True
        with open(path, "wb") as f:
            shutil.copyfileobj(r.raw, f)
        logger.info(f"File {url} successfully downloaded to {path}")
    else:
        logger.info(f"File {url} couldn't be retrieved")


@dataclasses.dataclass
class ImageType:
    mode: str = "RGB"
    range_: Tuple = (0, 255)
    channel: bool = True
    reverse: bool = False
    shape: Tuple = (None, 224, 224, 3)
    dtype: str = "uint8"
    counter: Optional[Counter] = None

    def __eq__(self, other):
        # Compare every field except the histogram counter.
        if not isinstance(other, ImageType):
            return NotImplemented
        return (filter_keys(dataclasses.asdict(self), "counter", reverse=True)
                == filter_keys(dataclasses.asdict(other), "counter", reverse=True))


def image_type(data, histogram=True):
    """Describe ``data`` as an ImageType, or return False if it is not image-like."""
    # NOTE: get_last_dim is neither defined nor imported in this module;
    # it must be provided elsewhere for histogram=True to work.
    data = np.asarray(data)
    if not if_images(data):
        return False
    shape = data.shape
    channel = True
    range_ = data.min(), data.max()
    counter = None
    reverse = None
    if shape[-1] == 3:
        mode = "RGB"
        if histogram:
            counter = Counter(list(get_last_dim(data, 3).reshape(-1)))
            reverse = counter.most_common(1)[0][0] == range_[1]
    else:
        mode = "L"
        if shape[-1] != 1:
            channel = False
            if histogram:
                counter = Counter(list(get_last_dim(data).reshape(-1)))
                reverse = counter.most_common(1)[0][0] == range_[1]
        else:
            if histogram:
                counter = Counter(list(get_last_dim(data, 3).reshape(-1)))
                reverse = counter.most_common(1)[0][0] == range_[1]
    return ImageType(mode, range_, channel, reverse, shape, data.dtype, counter)


timage = partial(image_type, histogram=False)


class OverrideDict(dict):
    def to_dict(self):
        # return {k: v for k, v in self.items()}
        return dict(self)

    def get_fn(self, index):
        return self.__getitem__(index)


class CalcDict(OverrideDict):
    """Dict that supports element-wise arithmetic with scalars, iterables, and dicts."""

    def operate(self, other, op, right=False):
        # Broadcast a scalar to one operand per entry.
        if not hasattr(other, "__iter__"):
            other = [other] * len(self)
        # Dicts are matched by position, not by key.
        if isinstance(other, dict):
            other = other.values()
        it = iter(other)
        if right:
            fn = lambda x, y: op(y, x)
        else:
            fn = op
        return CalcDict({k: fn(v, next(it)) for k, v in self.items()})

    def __add__(self, other):
        return self.operate(other, operator.add)

    def __sub__(self, other):
        return self.operate(other, operator.sub)

    def __mul__(self, other):
        return self.operate(other, operator.mul)

    def __truediv__(self, other):
        return self.operate(other, operator.truediv)

    def __radd__(self, other):
        return self.operate(other, operator.add, right=True)

    def __rsub__(self, other):
        return self.operate(other, operator.sub, right=True)

    def __rmul__(self, other):
        return self.operate(other, operator.mul, right=True)

    def __rtruediv__(self, other):
        return self.operate(other, operator.truediv, right=True)


def dict_from_list2(keys, values=None):
    """Simpler variant of dict_from_list: values default to 0..len(keys)-1."""
    if values is None:
        values = list(range(len(keys)))
    return dict(zip(keys[:len(lw(values))], lw(values)))