# Spaces:
# Build error
# Build error
import dataclasses
import inspect
import operator
import os
import re
import shutil
from functools import partial, wraps
from typing import Dict, Counter, Tuple

import numpy as np
import requests
from funcy import rcompose, identity
from loguru import logger
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.losses import Loss
def dir_decorator(fn):
    """Decorate ``fn`` so the folder of its first argument exists before the call.

    Args:
        fn: Callable whose first positional parameter is a file path or ``None``.

    Returns:
        A wrapper that takes ``(path, *args, **kwargs)``, creates the directory
        part of ``path`` when there is one, and then returns
        ``fn(path, *args, **kwargs)``. The wrapper keeps ``fn``'s name and
        docstring.
    """

    @wraps(fn)
    def deco(path, *args, **kwargs):
        if path is not None:
            folder = os.path.dirname(path)
            # A bare file name has an empty folder part; os.makedirs("") would raise.
            if folder:
                os.makedirs(folder, exist_ok=True)
        return fn(path, *args, **kwargs)

    return deco
def extract_args(args):
    """Normalise a flexible argument bundle into a flat tuple.

    Args:
        args: A list/tuple of items, a list/tuple whose first element is itself
            a list/tuple (that first element is then used and the rest is
            ignored), a single non-sequence value, or ``None``.

    Returns:
        tuple: ``()`` for ``None`` or an empty sequence, ``(args,)`` for a single
        non-sequence value, the first element as a tuple when it is a
        list/tuple, otherwise ``tuple(args)``.
    """
    if not isinstance(args, (list, tuple)):
        return () if args is None else (args,)
    # Guard the empty case: indexing args[0] on an empty sequence would raise.
    if args and isinstance(args[0], (list, tuple)):
        return tuple(args[0])
    return tuple(args)
def comp(*args):
    """Compose callables left to right.

    The callables may be passed as separate arguments or as one list/tuple;
    ``extract_args`` flattens either form before they are handed to
    ``rcompose``.
    """
    functions = extract_args(args)
    return rcompose(*functions)
def is_model(obj):
    """Return True when ``obj`` is a Keras ``Model``, ``Layer`` or ``Loss`` instance."""
    keras_types = (Model, Layer, Loss)
    return isinstance(obj, keras_types)
def list_dir(root):
    """Lazily yield every entry of ``root`` joined onto ``root``.

    Entries come from ``os.listdir`` and are yielded in the order it returns them.
    """
    yield from (os.path.join(root, entry) for entry in os.listdir(root))
def is_int(obj, bool_=False):
    """Return True when ``obj`` is a Python ``int`` or a NumPy integer scalar.

    ``bool`` is a subclass of ``int``; booleans are only accepted when
    ``bool_`` is truthy.
    """
    rejected_bool = isinstance(obj, bool) and not bool_
    return not rejected_bool and isinstance(obj, (int, np.integer))
def is_float(obj):
    """Return True when ``obj`` is a Python ``float`` or a NumPy floating scalar.

    ``np.floating`` is the abstract base of all NumPy float scalar types, so
    ``np.float16``/``np.float32``/``np.float64`` values are all accepted. The
    former ``np.float`` alias (a plain synonym of ``float``) no longer exists
    in current NumPy releases.
    """
    return isinstance(obj, (float, np.floating))
def is_num(obj, bool_=False):
    """Return True when ``obj`` passes ``is_int`` (honouring ``bool_``) or ``is_float``."""
    if is_int(obj, bool_):
        return True
    return is_float(obj)
def get_args(obj):
    """Return the parameter names of ``obj`` as a list.

    When the signature cannot be inspected, the error is logged and the
    single placeholder name ``["arg"]`` is returned instead.
    """
    try:
        signature = inspect.signature(obj)
    except Exception as e:
        logger.error(f"{e}. Returning 1.")
        return ["arg"]
    return list(signature.parameters)
def il(data):
    """Return True when ``data`` is a list or a tuple (including subclasses)."""
    sequence_types = (list, tuple)
    return isinstance(data, sequence_types)
def lw(data, none_empty=True, convert_tuple=True):
    """Wrap ``data`` in a list unless it already is one.

    Args:
        data: Any value.
        none_empty: When truthy, ``None`` becomes ``[]`` instead of ``[None]``.
        convert_tuple: When truthy, a tuple is converted to a list of its
            items; otherwise the tuple is wrapped as a single element.

    Returns:
        list: ``data`` itself (same object) when it is a list, otherwise a new list.
    """
    if isinstance(data, list):
        return data
    if convert_tuple and isinstance(data, tuple):
        return list(data)
    if data is None and none_empty:
        return []
    return [data]
def dict_from_list(keys=None, values=None) -> Dict:
    """Build a dict from keys and values, deriving whichever side is missing.

    Args:
        keys: Sequence of keys, a callable applied to each value, or ``None``
            to use ``0..n-1`` where ``n`` is the number of values.
        values: Values, a callable applied to each key, a ready-made dict
            (returned unchanged), or ``None`` to use ``0..n-1`` where ``n`` is
            the number of keys.

    Returns:
        Dict: ``values`` itself when it is a dict, otherwise the keys zipped
        with the list-wrapped values (keys are sliced to the number of values).
    """
    # Resolve the values first; a callable ``values`` is mapped over the keys.
    if values is None:
        values = list(range(len(lw(keys))))
    elif callable(values):
        values = list(map(values, keys))

    # Then resolve the keys; a callable ``keys`` is mapped over the values.
    if keys is None:
        keys = list(range(len(lw(values))))
    elif callable(keys):
        keys = list(map(keys, values))

    if isinstance(values, dict):
        return values
    value_list = lw(values)
    return dict(zip(keys[:len(value_list)], value_list))
def call(fn, x=None):
    """Call ``fn`` with ``x``, unpacking ``x`` according to its type.

    ``fn`` is called without arguments when ``get_args`` reports no
    parameters. Otherwise a dict is passed as keyword arguments, a list/tuple
    as positional arguments, and anything else as one positional argument.
    """
    if not get_args(fn):
        return fn()
    if isinstance(x, dict):
        return fn(**x)
    if il(x):
        return fn(*x)
    return fn(x)
def get_str_comp(str_comp):
    """Resolve a comparison spec into a two-argument callable ``(x, y)``.

    A callable is returned as-is. The strings ``"in"``, ``"equal"``/``"="``
    and ``"re"`` select membership (``x in y``), equality and
    ``re.match(x, y)`` respectively. Any other value falls back to equality.
    """
    if callable(str_comp):
        return str_comp

    def equals(x, y):
        return x == y

    if not isinstance(str_comp, str):
        return equals

    comparators = {
        "in": lambda x, y: x in y,
        "equal": equals,
        "=": equals,
        "re": lambda x, y: re.match(x, y),
    }
    return comparators.get(str_comp, equals)
def test(filters, x=None, str_comp="in"):
    """Evaluate every filter against ``x`` and return the list of results.

    Args:
        filters: One filter or a list/tuple of filters. A string filter is
            compared with ``x`` through the comparator from ``str_comp``, a
            callable filter is invoked through ``call``, and any other value
            is compared with ``==``.
        x: The value being tested.
        str_comp: Comparator spec resolved by ``get_str_comp``.

    Returns:
        list: One result per filter, in filter order.
    """
    compare = get_str_comp(str_comp)

    def evaluate(candidate):
        if isinstance(candidate, str):
            return compare(candidate, x)
        if callable(candidate):
            return call(candidate, x)
        return candidate == x

    return [evaluate(candidate) for candidate in lw(filters)]
# Reductions over `test`: run every filter, then collapse the result list
# with all() / any() respectively.
test_all = comp(test, all)
test_any = comp(test, any)
def filter_keys(d, keys, reverse=False, str_comp="=", *args, **kwargs):
    """Return the entries of ``d`` whose key matches any of ``keys``.

    Args:
        d: A dict, or a dataclass instance (converted with ``dataclasses.asdict``).
        keys: Filters handed to ``test_any``; a dataclass instance or a dict
            contributes its field names / keys.
        reverse: When truthy, keep the entries that match none of ``keys``.
        str_comp: Comparator spec used for string filters.
        *args, **kwargs: Forwarded unchanged to the matching predicate.

    Returns:
        dict: The selected ``key: value`` pairs, in the order of ``d``.
    """
    if dataclasses.is_dataclass(keys):
        keys = dataclasses.asdict(keys)
    if isinstance(keys, dict):
        keys = list(keys.keys())
    if dataclasses.is_dataclass(d):
        d = dataclasses.asdict(d)

    # Negate the "any filter matched" predicate when excluding keys.
    matcher = comp(test_any, lambda matched: not matched) if reverse else test_any

    selected = {}
    for name, value in d.items():
        if matcher(keys, name, *args, str_comp=str_comp, **kwargs):
            selected[name] = value
    return selected
def none(x):
    """Return a fresh empty list for ``None``; any other value is returned unchanged."""
    if x is None:
        return []
    return x
def if_images(data):
    """Return True when the shape of ``data`` passes the image-stack test.

    Objects without a ``shape`` attribute are rejected. Otherwise the shape
    must have more than three dimensions, or exactly three with a last
    dimension larger than 4.
    """
    if not hasattr(data, "shape"):
        return False
    dims = data.shape
    rank = len(dims)
    if rank > 3:
        return True
    return rank == 3 and dims[-1] > 4
def download_file(url, path=None):
    """Stream the resource at ``url`` into a local file.

    Args:
        url: Address to fetch with an HTTP GET request.
        path: Destination file. Defaults to the text after the last ``/`` in
            ``url``.

    Returns:
        None. The outcome is only reported through the logger: the file is
        written when the server answers with status 200; any other status is
        logged and nothing is written.
    """
    if path is None:
        path = url.split("/")[-1]
    # The context manager releases the streamed connection on every exit path,
    # including an error while writing the file.
    with requests.get(url, stream=True) as response:
        if response.status_code == 200:
            # Ask the raw stream to undo any content encoding while it is copied.
            response.raw.decode_content = True
            with open(path, "wb") as f:
                shutil.copyfileobj(response.raw, f)
            logger.info(f"File {url} successfully downloaded to {path}")
        else:
            logger.warning(
                f"File {url} couldn't be retrieved (HTTP {response.status_code})"
            )
@dataclasses.dataclass
class ImageType:
    """Summary of how an image array is laid out.

    Declared as a dataclass so that it can be built positionally and compared
    field by field. Equality ignores ``counter``.
    """

    # Colour mode label, e.g. "RGB" or "L".
    mode: str = "RGB"
    # (minimum, maximum) value pair.
    range_: Tuple = (0, 255)
    # Whether the array carries an explicit channel axis.
    channel: bool = True
    # Whether the most frequent value equals the maximum of ``range_``.
    reverse: bool = False
    # Array shape.
    shape: Tuple = (None, 224, 224, 3)
    # Element type of the array.
    dtype: str = "uint8"
    # Optional value histogram; excluded from equality.
    counter: Counter = None

    @staticmethod
    def _comparable(obj):
        """Return the fields of dataclass instance ``obj`` except ``counter``."""
        return {
            field.name: getattr(obj, field.name)
            for field in dataclasses.fields(obj)
            if field.name != "counter"
        }

    def __eq__(self, other):
        """Compare every field except ``counter`` with another dataclass instance.

        Operands that are not dataclass instances yield ``NotImplemented``, so
        comparing with e.g. ``False`` evaluates to ``False`` instead of raising.
        """
        if not dataclasses.is_dataclass(other) or isinstance(other, type):
            return NotImplemented
        # Read the fields directly instead of via asdict(): that would deep-copy
        # the (potentially large) histogram just to throw it away.
        return self._comparable(self) == self._comparable(other)
def image_type(data, histogram=True):
    """Describe ``data`` as an ``ImageType``, or return False if it is not image-like.

    Args:
        data: Array-like input; converted with ``np.asarray``.
        histogram: When truthy, count the values returned by ``get_last_dim``
            and set ``reverse`` to whether the most common value equals the
            maximum of the data.

    Returns:
        ``False`` when ``if_images`` rejects the array, otherwise an
        ``ImageType`` built from the mode, value range, channel flag, reverse
        flag, shape, dtype and histogram.
    """
    data = np.asarray(data)
    if not if_images(data):
        return False

    shape = data.shape
    range_ = data.min(), data.max()
    last_dim = shape[-1]

    # NOTE(review): get_last_dim is not defined in the visible part of the file;
    # the extra argument 3 is passed for every case except a last dimension
    # that is neither 3 nor 1 -- confirm this asymmetry is intended.
    if last_dim == 3:
        mode, channel, extra_args = "RGB", True, (3,)
    elif last_dim != 1:
        mode, channel, extra_args = "L", False, ()
    else:
        mode, channel, extra_args = "L", True, (3,)

    counter = None
    reverse = None
    if histogram:
        counter = Counter(list(get_last_dim(data, *extra_args).reshape(-1)))
        reverse = counter.most_common(1)[0][0] == range_[1]

    return ImageType(mode, range_, channel, reverse, shape, data.dtype, counter)
# Variant of image_type with the histogram pass disabled (histogram=False):
# image_type then leaves `counter` and `reverse` at None.
timage = partial(image_type, histogram=False)
class OverrideDict(dict):
    """A ``dict`` subclass with two small convenience accessors."""

    def to_dict(self):
        """Return a shallow copy of the mapping as a plain ``dict``."""
        plain = {}
        plain.update(self)
        return plain

    def get_fn(self, index):
        """Return the value stored under ``index``; raises ``KeyError`` if absent."""
        return self[index]
class CalcDict(OverrideDict):
    """Dictionary supporting element-wise ``+``, ``-``, ``*`` and ``/``."""

    def operate(self, other, op, right=False):
        """Apply ``op`` between each value and the matching operand of ``other``.

        Args:
            other: An object without ``__iter__`` is repeated for every entry.
                A dict contributes its values. Any other iterable is consumed
                one item per entry, paired by position rather than by key.
            op: Binary callable such as ``operator.add``.
            right: When truthy, the operand from ``other`` is passed to ``op``
                first (reflected operation).

        Returns:
            CalcDict: Same keys as ``self`` with the computed values.
        """
        if not hasattr(other, "__iter__"):
            other = [other] * len(self)
        if isinstance(other, dict):
            other = other.values()
        operands = iter(other)

        if right:
            def apply(mine, theirs):
                return op(theirs, mine)
        else:
            apply = op

        result = CalcDict()
        for key, mine in self.items():
            # next() without a default: an iterable shorter than self raises StopIteration.
            result[key] = apply(mine, next(operands))
        return result

    def __add__(self, other):
        """``self + other``."""
        return self.operate(other, operator.add)

    def __sub__(self, other):
        """``self - other``."""
        return self.operate(other, operator.sub)

    def __mul__(self, other):
        """``self * other``."""
        return self.operate(other, operator.mul)

    def __truediv__(self, other):
        """``self / other``."""
        return self.operate(other, operator.truediv)

    def __radd__(self, other):
        """``other + self``."""
        return self.operate(other, operator.add, right=True)

    def __rsub__(self, other):
        """``other - self``."""
        return self.operate(other, operator.sub, right=True)

    def __rmul__(self, other):
        """``other * self``."""
        return self.operate(other, operator.mul, right=True)

    def __rtruediv__(self, other):
        """``other / self``."""
        return self.operate(other, operator.truediv, right=True)
def dict_from_list2(keys, values=None):
    """Zip ``keys`` with ``values`` into a dict.

    When ``values`` is ``None`` the keys are numbered ``0..len(keys)-1``.
    Otherwise ``values`` is list-wrapped with ``lw`` and the keys are sliced
    to the number of values before zipping.
    """
    if values is None:
        values = list(range(len(keys)))
    value_list = lw(values)
    return dict(zip(keys[:len(value_list)], value_list))