File size: 1,075 Bytes
ac9c2b2 932e7b4 ac9c2b2 be96dd0 932e7b4 ac9c2b2 be96dd0 932e7b4 be96dd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from typing import Callable, Dict, Any, List, Optional
class ModelEntry:
    """A registered model bundled with its processing hooks and descriptive metadata.

    This is a plain attribute container: every constructor argument is stored,
    unchanged, on an attribute of the same name.
    """

    def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
                 display_name: Optional[str] = None, contributor: Optional[str] = None,
                 model_path: Optional[str] = None, architecture: Optional[str] = None,
                 dataset: Optional[str] = None):
        """Store the model, its hooks, and optional metadata.

        Args:
            model: The model object itself; its type is not constrained here.
            preprocess: Callable stored for callers to use — presumably applied
                to inputs before ``model``; this class never invokes it.
            postprocess: Callable stored for callers to use — presumably applied
                to ``model`` outputs; this class never invokes it.
            class_names: Label names; stored by reference (not copied).
            display_name: Optional human-readable name.
            contributor: Optional name of whoever provided the model.
            model_path: Optional path the model was loaded from (not read here).
            architecture: Optional architecture label.
            dataset: Optional name of the training dataset.
        """
        self.model = model
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.class_names = class_names
        self.display_name = display_name
        self.contributor = contributor
        self.model_path = model_path
        self.architecture = architecture
        self.dataset = dataset

    def __repr__(self) -> str:
        """Summarize the entry's metadata for logs and debugging."""
        # Only the model's type name is shown: a full model repr may be very large.
        num_classes = len(self.class_names) if self.class_names is not None else None
        return (
            f"{type(self).__name__}("
            f"display_name={self.display_name!r}, "
            f"architecture={self.architecture!r}, "
            f"dataset={self.dataset!r}, "
            f"contributor={self.contributor!r}, "
            f"model_path={self.model_path!r}, "
            f"num_classes={num_classes}, "
            f"model_type={type(self.model).__name__})"
        )
# Module-level lookup table mapping a model_id to its ModelEntry;
# register_model writes into it (re-registering an id replaces the entry).
MODEL_REGISTRY: Dict[str, ModelEntry] = dict()
def register_model(model_id: str, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
                   architecture: Optional[str] = None, dataset: Optional[str] = None, *,
                   display_name: Optional[str] = None, contributor: Optional[str] = None,
                   model_path: Optional[str] = None) -> "ModelEntry":
    """Build a ModelEntry and store it in MODEL_REGISTRY under ``model_id``.

    Registering an id that already exists silently replaces the previous entry.

    Args:
        model_id: Registry key for the model.
        model: The model object.
        preprocess: Callable stored on the entry as ``preprocess``.
        postprocess: Callable stored on the entry as ``postprocess``.
        class_names: Label names stored on the entry.
        architecture: Optional architecture label.
        dataset: Optional training-dataset name.
        display_name: Optional human-readable name (keyword-only).
        contributor: Optional contributor name (keyword-only).
        model_path: Optional path the model came from (keyword-only).

    Returns:
        The newly created and registered ModelEntry.
    """
    # Forward every metadata field so registered entries are not limited to
    # architecture/dataset; the new fields default to None as before.
    entry = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )
    MODEL_REGISTRY[model_id] = entry
    return entry