raven / raven_utils /tools.py
Jakub Kwiatkowski
Refactor.
575eae1
raw
history blame
7.26 kB
import dataclasses
import inspect
import os
import re
import shutil
from functools import partial
from typing import Dict, Counter, Tuple
import operator
import numpy as np
import requests
from funcy import rcompose, identity
from loguru import logger
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.losses import Loss
def dir_decorator(fn):
    """Decorate ``fn`` so the folder of its first argument exists before the call.

    The wrapped callable must take a file path as its first positional
    argument. Before delegating, the directory part of that path is created
    (including intermediate folders). A ``None`` path, or a path without a
    directory component, is passed through without touching the filesystem.

    Args:
        fn: Callable whose first positional argument is a file path.

    Returns:
        A wrapper with the same call signature and metadata as ``fn``.
    """
    # Imported locally because the module only pulls ``partial`` from functools.
    from functools import wraps

    @wraps(fn)  # keep __name__/__doc__ of the wrapped function for debugging
    def deco(path, *args, **kwargs):
        if path is not None:
            folder = os.path.dirname(path)
            # An empty dirname means the current directory: nothing to create.
            if folder:
                os.makedirs(folder, exist_ok=True)
        return fn(path, *args, **kwargs)

    return deco
def extract_args(args):
    """Normalize a variadic argument pack into a flat tuple.

    Accepts either the items themselves or a single list/tuple holding them,
    so that ``f(a, b)`` and ``f([a, b])`` are treated the same by callers.

    Args:
        args: A list/tuple of items, a list/tuple whose first element is
            itself a list/tuple (that first element is then used and the rest
            ignored), or a single non-sequence value.

    Returns:
        A tuple of the extracted items. An empty sequence yields ``()``.
    """
    if il(args):
        # Check for emptiness before peeking at element 0; otherwise an empty
        # pack (e.g. ``comp()``) would raise IndexError.
        if args and il(args[0]):
            return tuple(args[0])
        return tuple(args)
    return tuple(lw(args))
def comp(*args):
    """Compose callables left to right.

    The callables may be given as separate arguments or as one list/tuple;
    ``extract_args`` flattens both forms before they reach ``rcompose``.
    """
    functions = extract_args(args)
    return rcompose(*functions)
def is_model(obj):
    """Return True when ``obj`` is a Keras ``Model``, ``Layer`` or ``Loss`` instance."""
    keras_types = (Model, Layer, Loss)
    return isinstance(obj, keras_types)
def list_dir(root):
    """Lazily yield every entry of ``root`` joined with ``root`` itself.

    The directory is only read once iteration starts.
    """
    yield from (os.path.join(root, entry) for entry in os.listdir(root))
def is_int(obj, bool_=False):
    """Tell whether ``obj`` is an integer (Python ``int`` or NumPy integer).

    Args:
        obj: Value to inspect.
        bool_: When false (the default) ``True``/``False`` are rejected even
            though ``bool`` is a subclass of ``int``.
    """
    integral = isinstance(obj, (int, np.integer))
    if isinstance(obj, bool) and not bool_:
        return False
    return integral
def is_float(obj):
    """Tell whether ``obj`` is a float (Python ``float`` or NumPy floating scalar).

    ``np.float`` was only an alias of the builtin ``float`` and has been
    removed from NumPy, so the check uses the abstract ``np.floating`` type,
    matching how ``is_int`` relies on ``np.integer``.
    """
    return isinstance(obj, (float, np.floating))
def is_num(obj, bool_=False):
    """Return True for integers (see ``is_int``; ``bool_`` is forwarded) and floats."""
    if is_int(obj, bool_):
        return True
    return is_float(obj)
def get_args(obj):
    """List the parameter names of callable ``obj``.

    If the signature cannot be inspected, the error is logged and a
    single-element placeholder list ``["arg"]`` is returned instead.
    """
    try:
        parameters = inspect.signature(obj).parameters
        return [name for name in parameters]
    except Exception as e:
        logger.error(f"{e}. Returning 1.")
        return ["arg"]
def il(data):
    """Return True when ``data`` is a list or a tuple."""
    return isinstance(data, list) or isinstance(data, tuple)
def lw(data, none_empty=True, convert_tuple=True):
    """Wrap ``data`` in a list unless it already is one.

    Args:
        data: Any value.
        none_empty: When true, ``None`` becomes an empty list instead of ``[None]``.
        convert_tuple: When true, a tuple is converted to a list of its items;
            otherwise the tuple is wrapped as a single element.

    Returns:
        ``data`` itself if it is a list, otherwise a new list.
    """
    if isinstance(data, list):
        return data
    if convert_tuple and isinstance(data, tuple):
        return list(data)
    if data is None and none_empty:
        return []
    return [data]
def dict_from_list(keys=None, values=None) -> Dict:
    """Build a dict from keys and values, deriving whichever side is missing.

    Args:
        keys: Sequence of keys, a callable applied to each value to produce
            its key, or ``None`` to use ``0..n-1``.
        values: Values, a callable applied to each key to produce its value,
            ``None`` to use ``0..n-1``, or a dict (returned unchanged).

    Returns:
        The resulting mapping. Keys beyond the number of values are dropped.
    """
    # Values are resolved first because derived keys may depend on them.
    if values is None:
        values = list(range(len(lw(keys))))
    elif callable(values):
        values = [values(key) for key in keys]

    if keys is None:
        keys = list(range(len(lw(values))))
    elif callable(keys):
        keys = [keys(value) for value in values]

    if isinstance(values, dict):
        return values
    wrapped = lw(values)
    return dict(zip(keys[:len(wrapped)], wrapped))
def call(fn, x=None):
    """Invoke ``fn`` with ``x`` spread according to its type.

    A function without parameters is called with no arguments. Otherwise a
    dict is passed as keyword arguments, a list/tuple as positional
    arguments, and anything else as a single positional argument.
    """
    if not get_args(fn):
        return fn()
    if isinstance(x, dict):
        return fn(**x)
    if il(x):
        return fn(*x)
    return fn(x)
def get_str_comp(str_comp):
    """Resolve a comparison spec into a two-argument predicate ``(x, y)``.

    A callable is returned as is. The strings ``"in"``, ``"equal"``/``"="``
    and ``"re"`` select containment, equality and ``re.match(x, y)``
    respectively. Anything else falls back to equality.
    """
    if callable(str_comp):
        return str_comp

    def equals(x, y):
        return x == y

    def contained(x, y):
        return x in y

    def matches(x, y):
        return re.match(x, y)

    if isinstance(str_comp, str):
        named = {"in": contained, "equal": equals, "=": equals, "re": matches}
        return named.get(str_comp, equals)
    return equals
def test(filters, x=None, str_comp="in"):
    """Evaluate each filter against ``x`` and return the list of outcomes.

    Args:
        filters: One filter or a list/tuple of filters. A string filter is
            compared to ``x`` with the ``str_comp`` predicate, a callable is
            invoked on ``x`` through ``call``, and any other value is tested
            for equality with ``x``.
        x: Value under test.
        str_comp: Comparison spec understood by ``get_str_comp``.

    Returns:
        A list with one result per filter, in order.
    """
    compare = get_str_comp(str_comp)

    def check(candidate):
        if isinstance(candidate, str):
            return compare(candidate, x)
        if callable(candidate):
            return call(candidate, x)
        return candidate == x

    return [check(candidate) for candidate in lw(filters)]
# Reductions over ``test``: true when every / at least one filter matches.
test_all = comp(test, all)
test_any = comp(test, any)
def filter_keys(d, keys, reverse=False, str_comp="=", *args, **kwargs):
    """Keep only the entries of ``d`` whose key matches one of ``keys``.

    Args:
        d: Dict or dataclass instance to filter (a dataclass is converted
            with ``dataclasses.asdict``).
        keys: Filter or filters accepted by ``test_any``. A dataclass or dict
            contributes its field/key names.
        reverse: When true, keep the entries that do NOT match.
        str_comp: Comparison spec for string filters (default: equality).
        *args, **kwargs: Forwarded to the matching function.

    Returns:
        A new dict with the selected items.
    """
    if dataclasses.is_dataclass(keys):
        keys = dataclasses.asdict(keys)
    if isinstance(keys, dict):
        keys = list(keys.keys())
    if dataclasses.is_dataclass(d):
        d = dataclasses.asdict(d)

    predicate = comp(test_any, lambda matched: not matched) if reverse else test_any

    selected = {}
    for name, value in d.items():
        if predicate(keys, name, str_comp=str_comp, *args, **kwargs):
            selected[name] = value
    return selected
def none(x):
    """Replace ``None`` with a fresh empty list; return anything else untouched."""
    if x is None:
        return []
    return x
def if_images(data):
    """Heuristically decide from ``data.shape`` whether ``data`` holds images.

    Objects without a ``shape`` attribute are rejected. Anything with more
    than three dimensions is accepted; a three-dimensional shape is accepted
    only when its last dimension is larger than 4.
    """
    if not hasattr(data, "shape"):
        return False
    shape = data.shape
    rank = len(shape)
    if rank > 3:
        return True
    return rank == 3 and shape[-1] > 4
def download_file(url, path=None):
    """Stream the resource at ``url`` into a local file.

    Args:
        url: Address to fetch with an HTTP GET request.
        path: Destination file. Defaults to the last ``/``-separated
            component of ``url``.

    The outcome is only logged; nothing is returned and a non-200 status is
    not raised as an error.
    """
    if path is None:
        path = url.split("/")[-1]
    # With ``stream=True`` the connection stays checked out until the body is
    # consumed or the response is closed, so close it explicitly, including
    # on the non-200 path where the body is never read.
    with requests.get(url, stream=True) as r:
        if r.status_code == 200:
            # Have urllib3 undo gzip/deflate transfer encoding while copying.
            r.raw.decode_content = True
            with open(path, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
            logger.info(f'File {url} sucessfully downloaded to {path} ')
        else:
            logger.info(f'File {url} couldn\'t be retreived')
@dataclasses.dataclass
class ImageType:
    """Summary of how a batch of images is encoded.

    Two instances compare equal when every field except ``counter`` matches,
    so the value histogram never takes part in equality.
    """

    # Colour mode label, e.g. "RGB" or "L".
    mode: str = "RGB"
    # (minimum, maximum) value found in / expected for the data.
    range_: Tuple = (0, 255)
    # Whether the data carries an explicit channel axis.
    channel: bool = True
    # Flag stored alongside the histogram; may be None when not computed.
    reverse: bool = False
    # Array shape; None stands for an unspecified dimension.
    shape: Tuple = (None, 224, 224, 3)
    # Element type, either a name or a dtype object that compares equal to it.
    dtype: str = "uint8"
    # Optional histogram of values; ignored by ``__eq__``.
    counter: Counter = None

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # (e.g. ``False == ImageType()``) instead of failing inside the
        # comparison.
        if not isinstance(other, ImageType):
            return NotImplemented
        # Compare attributes directly: no deep copy of the (possibly large)
        # histogram is needed just to discard it afterwards.
        return all(
            getattr(self, field.name) == getattr(other, field.name)
            for field in dataclasses.fields(self)
            if field.name != "counter"
        )
def image_type(data, histogram=True):
    """Describe ``data`` as an ``ImageType``, or return False if it is not image-like.

    Args:
        data: Array-like input; converted with ``np.asarray``.
        histogram: When true, also count the values and set ``reverse`` to
            whether the most common value equals the maximum of the data.

    Returns:
        ``False`` when ``if_images`` rejects the data, otherwise an
        ``ImageType`` built from the shape, value range and dtype.
    """
    data = np.asarray(data)
    if not if_images(data):
        return False

    shape = data.shape
    range_ = data.min(), data.max()
    last_dim = shape[-1]

    # A trailing dimension of 3 is labelled RGB, everything else "L".
    mode = "RGB" if last_dim == 3 else "L"
    # Only a trailing dimension of 3 or 1 counts as a channel axis.
    channel = last_dim == 3 or last_dim == 1

    counter = None
    reverse = None
    if histogram:
        # NOTE(review): ``get_last_dim`` is not defined in the visible part of
        # this file; confirm where it comes from and what the 3 selects.
        if channel:
            values = get_last_dim(data, 3)
        else:
            values = get_last_dim(data)
        counter = Counter(list(values.reshape(-1)))
        reverse = counter.most_common(1)[0][0] == range_[1]

    return ImageType(mode, range_, channel, reverse, shape, data.dtype, counter)


# Variant of ``image_type`` that skips the histogram computation.
timage = partial(image_type, histogram=False)
class OverrideDict(dict):
    """``dict`` subclass with two small convenience accessors."""

    def to_dict(self):
        """Return a shallow copy of the content as a plain ``dict``."""
        plain = dict(self)
        return plain

    def get_fn(self, index):
        """Return the item stored under ``index`` (same as ``self[index]``)."""
        return self[index]
class CalcDict(OverrideDict):
    """Dict supporting element-wise ``+``, ``-``, ``*`` and ``/`` on its values."""

    def operate(self, other, op, right=False):
        """Apply binary ``op`` between each value and the matching operand.

        Args:
            other: A non-iterable value (reused for every entry), a dict (its
                values are used in iteration order) or any other iterable,
                consumed one element per entry.
            op: Two-argument callable such as ``operator.add``.
            right: When true the operands are swapped, i.e. ``op(other, value)``.

        Returns:
            A new ``CalcDict`` with the same keys.
        """
        if not hasattr(other, "__iter__"):
            other = [other] * len(self)
        if isinstance(other, dict):
            other = other.values()
        operands = iter(other)

        def apply(mine, theirs):
            return op(theirs, mine) if right else op(mine, theirs)

        result = CalcDict()
        for key, value in self.items():
            result[key] = apply(value, next(operands))
        return result

    def __add__(self, other):
        return self.operate(other, operator.add)

    def __sub__(self, other):
        return self.operate(other, operator.sub)

    def __mul__(self, other):
        return self.operate(other, operator.mul)

    def __truediv__(self, other):
        return self.operate(other, operator.truediv)

    # Reflected variants: the dict is the right-hand operand.
    def __radd__(self, other):
        return self.operate(other, operator.add, right=True)

    def __rsub__(self, other):
        return self.operate(other, operator.sub, right=True)

    def __rmul__(self, other):
        return self.operate(other, operator.mul, right=True)

    def __rtruediv__(self, other):
        return self.operate(other, operator.truediv, right=True)
def dict_from_list2(keys, values=None):
    """Pair ``keys`` with ``values``; missing values default to ``0..len(keys)-1``.

    A single non-list value is wrapped with ``lw``. Keys beyond the number of
    values are dropped.
    """
    if values is None:
        values = list(range(len(keys)))
    wrapped = lw(values)
    return dict(zip(keys[:len(wrapped)], wrapped))