File size: 1,075 Bytes
ac9c2b2
932e7b4
 
ac9c2b2
be96dd0
 
932e7b4
 
 
 
ac9c2b2
 
 
be96dd0
 
932e7b4
 
 
be96dd0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import Callable, Dict, Any, List, Optional

class ModelEntry:
    """Bundle a model object with its pre/post-processing callables and metadata.

    Every constructor argument is stored, unchanged and uncopied, as an
    instance attribute of the same name. No validation is performed.
    """

    def __init__(
        self,
        model: Any,
        preprocess: Callable,
        postprocess: Callable,
        class_names: List[str],
        display_name: Optional[str] = None,
        contributor: Optional[str] = None,
        model_path: Optional[str] = None,
        architecture: Optional[str] = None,
        dataset: Optional[str] = None,
    ):
        # Required pieces: the model object and the two callables paired with it.
        self.model, self.preprocess, self.postprocess = model, preprocess, postprocess
        # The caller's list object is kept as-is (same identity, not copied).
        self.class_names = class_names
        # Optional descriptive metadata; each one defaults to None.
        (
            self.display_name,
            self.contributor,
            self.model_path,
            self.architecture,
            self.dataset,
        ) = (display_name, contributor, model_path, architecture, dataset)

# Module-level registry mapping a model id to its ModelEntry; filled in by
# register_model(), which replaces any existing entry stored under the same id.
MODEL_REGISTRY: Dict[str, ModelEntry] = dict()

def register_model(model_id: str, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
                   architecture: Optional[str] = None, dataset: Optional[str] = None,
                   display_name: Optional[str] = None, contributor: Optional[str] = None,
                   model_path: Optional[str] = None):
    """Build a ModelEntry and store it in MODEL_REGISTRY under ``model_id``.

    Args:
        model_id: Registry key for the entry.
        model: The model object to store.
        preprocess: Callable stored as ``ModelEntry.preprocess``.
        postprocess: Callable stored as ``ModelEntry.postprocess``.
        class_names: Label list stored as ``ModelEntry.class_names``.
        architecture: Optional architecture description.
        dataset: Optional dataset description.
        display_name: Optional human-readable name.
        contributor: Optional contributor name.
        model_path: Optional path associated with the model.

    Returns:
        None. An entry already registered under ``model_id`` is silently
        replaced.
    """
    # display_name / contributor / model_path were previously never forwarded,
    # so entries created here could not carry that metadata. They are appended
    # as trailing keyword-friendly parameters so existing positional calls
    # keep working unchanged.
    MODEL_REGISTRY[model_id] = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )