Spaces:
Runtime error
Runtime error
File size: 466 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import is_model_wrapper
def get_ori_model(model: nn.Module) -> nn.Module:
    """Return the underlying model, stripping one model-wrapper layer.

    Whether ``model`` counts as a wrapper is decided by
    ``mmengine.model.is_model_wrapper``. If it does, the wrapped model is
    taken from its ``module`` attribute. Otherwise ``model`` is returned
    unchanged.

    Args:
        model (nn.Module): A model which may or may not be a model wrapper.

    Returns:
        nn.Module: ``model.module`` if ``model`` is a model wrapper,
        otherwise ``model`` itself.
    """
    return model.module if is_model_wrapper(model) else model
|