from typing import Callable, Dict, Any, List


class ModelEntry:
    """Bundle a model with its pre/post-processing callables and class names.

    Attributes:
        model: The model object, stored as given (any type).
        preprocess: Callable stored alongside the model.
            NOTE(review): presumably applied to inputs before the model
            runs; confirm against callers.
        postprocess: Callable stored alongside the model.
            NOTE(review): presumably applied to the model's raw output;
            confirm against callers.
        class_names: List of class-name strings associated with the model.
    """

    def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str]) -> None:
        """Store the four components unchanged on the instance."""
        self.model = model
        self.preprocess = preprocess
        self.postprocess = postprocess
        # Kept by reference (no copy), so later mutation of the caller's
        # list is visible through this entry.
        self.class_names = class_names


# Module-level registry mapping a model identifier to its ModelEntry.
# Starts empty; register_model() adds (or replaces) entries.
MODEL_REGISTRY: Dict[str, ModelEntry] = {}


def register_model(model_id: str, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str]) -> None:
    """Register a model under ``model_id`` in ``MODEL_REGISTRY``.

    Wraps the given components in a ``ModelEntry`` and stores it in the
    module-level registry. If ``model_id`` is already present, the
    existing entry is replaced.

    Args:
        model_id: Key under which the entry is stored.
        model: The model object to register.
        preprocess: Callable stored on the entry as ``preprocess``.
        postprocess: Callable stored on the entry as ``postprocess``.
        class_names: Class-name strings stored on the entry.
    """
    MODEL_REGISTRY[model_id] = ModelEntry(model, preprocess, postprocess, class_names)