Valeriy Sinyukov commited on
Commit
351e3b7
·
1 Parent(s): eafc49a

Add module which gathers all models

Browse files

This module collects models dict, which maps model name to function loading this model.

src/category_classification/__init__.py ADDED
File without changes
src/category_classification/models/__init__.py ADDED
File without changes
src/category_classification/models/models.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import os
3
+ import sys
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ def import_model_module(file_path: os.PathLike):
8
+ module_name = str(Path(file_path).relative_to(os.getcwd())).replace(os.path.sep, ".")
9
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
10
+ module = importlib.util.module_from_spec(spec)
11
+ sys.modules[module_name] = module
12
+ spec.loader.exec_module(module)
13
+ return module
14
+
15
+
16
+ models = {}
17
+
18
+ file_dir = Path(__file__).parents[0]
19
+
20
+ for path in file_dir.glob("*"):
21
+ if path.is_dir():
22
+ model_file_path = path / "model.py"
23
+ if not model_file_path.exists():
24
+ continue
25
+ module = import_model_module(model_file_path)
26
+ name_key = "name"
27
+ get_model_key = "get_model"
28
+ name = getattr(module, name_key, None)
29
+ get_model = getattr(module, get_model_key, None)
30
+
31
+ def check_attr_exists(attr_name, attr):
32
+ if attr is None:
33
+ warnings.warn(
34
+ f"Module {model_file_path} should define attribute '{attr_name}'"
35
+ )
36
+ return False
37
+ return True
38
+
39
+ def check_attr_type(attr_name, attr, type):
40
+ if isinstance(attr, type):
41
+ return True
42
+ warnings.warn(
43
+ f"'{attr_name}' should be of type {type}, but it is of type {type(attr)}"
44
+ )
45
+ return False
46
+
47
+ def check_attr_callable(attr_name, attr):
48
+ if callable(attr):
49
+ return True
50
+ warnings.warn(f"'{attr_name}' should be callable")
51
+ return False
52
+
53
+ if not check_attr_exists(name_key, name):
54
+ continue
55
+ if not check_attr_exists(get_model_key, get_model):
56
+ continue
57
+ if not check_attr_type(name_key, name, str):
58
+ continue
59
+ if not check_attr_callable(get_model_key, get_model):
60
+ continue
61
+
62
+ models[name] = get_model
63
+
64
+
65
+ def get_model(name: str):
66
+ if name not in models:
67
+ raise KeyError(f"No model with name {name}")
68
+ return models[name]()
69
+
70
+
71
+ def get_all_model_names():
72
+ return list(models.keys())