# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import logging
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn

from doctr.utils.data import download_from_url
# Public API of this module (private tensor helpers are exported deliberately
# so sibling doctr modules can reuse them).
__all__ = [
    "load_pretrained_params",
    "conv_sequence_pt",
    "set_device_and_dtype",
    "export_model_to_onnx",
    "_copy_tensor",
    "_bf16_to_float32",
]
def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x`` detached from the autograd graph.

    The result shares no storage with the input, so mutating it cannot
    affect the original tensor.

    Args:
    ----
        x: tensor to duplicate

    Returns:
    -------
        a detached clone of ``x``
    """
    return x.clone().detach()
def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
    """Upcast a bfloat16 tensor to float32; return other dtypes unchanged.

    bfloat16 is not supported by ``.numpy()``
    (torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype), so callers that
    bridge to NumPy route tensors through this helper first.

    Args:
    ----
        x: input tensor

    Returns:
    -------
        ``x.float()`` if ``x`` is bfloat16, otherwise ``x`` itself (no copy)
    """
    return x.float() if x.dtype == torch.bfloat16 else x
def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model.

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
    ----
        model: the PyTorch model to be loaded
        url: URL of the zipped set of parameters
        hash_prefix: first characters of SHA256 expected hash
        ignore_keys: list of weights to be ignored from the state_dict
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`

    Raises:
    ------
        ValueError: if ``ignore_keys`` is set and the remaining keys do not match the model
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
        return

    archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
    # Read state_dict.
    # NOTE(review): torch.load unpickles the archive — the URL must point to a trusted source.
    state_dict = torch.load(archive_path, map_location="cpu")

    if ignore_keys is not None and len(ignore_keys) > 0:
        # Remove the weights to be ignored; tolerate keys already absent from the
        # checkpoint (pop default) — the strict check below reports any mismatch.
        for key in ignore_keys:
            state_dict.pop(key, None)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        # Exactly the ignored keys may be missing; anything else is an error
        if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
            raise ValueError("unable to load state_dict, due to non-matching keys.")
    else:
        # Load weights strictly
        model.load_state_dict(state_dict)
def conv_sequence_pt(
    in_channels: int,
    out_channels: int,
    relu: bool = False,
    bn: bool = False,
    **kwargs: Any,
) -> List[nn.Module]:
    """Builds a convolutional-based layer sequence.

    >>> from torch.nn import Sequential
    >>> from doctr.models import conv_sequence
    >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))

    Args:
    ----
        in_channels: number of input channels
        out_channels: number of output channels
        relu: whether ReLU should be used
        bn: should a batch normalization layer be added
        **kwargs: additional arguments to be passed to the convolutional layer

    Returns:
    -------
        list of layers: Conv2d [, BatchNorm2d] [, ReLU]
    """
    # The conv bias is redundant when followed by a BatchNorm (which has its own
    # learnable shift), so it defaults to off in that case; callers may override.
    kwargs["bias"] = kwargs.get("bias", not bn)

    conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
    if bn:
        conv_seq.append(nn.BatchNorm2d(out_channels))
    if relu:
        conv_seq.append(nn.ReLU(inplace=True))

    return conv_seq
def set_device_and_dtype(
    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
) -> Tuple[Any, List[torch.Tensor]]:
    """Set the device and dtype of a model and its batches.

    >>> import torch
    >>> from torch import nn
    >>> from doctr.models.utils import set_device_and_dtype
    >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    >>> batches = [torch.rand(8) for _ in range(2)]
    >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)

    Args:
    ----
        model: the model to be set
        batches: the batches to be set
        device: the device to be used
        dtype: the dtype to be used

    Returns:
    -------
        the model and batches set
    """
    # nn.Module.to is in-place for the module, while Tensor.to returns new tensors;
    # both are returned so callers can rebind uniformly.
    return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str:
    """Export model to ONNX format.

    >>> import torch
    >>> from doctr.models.classification import resnet18
    >>> from doctr.models.utils import export_model_to_onnx
    >>> model = resnet18(pretrained=True)
    >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

    Args:
    ----
        model: the PyTorch model to be exported
        model_name: the name for the exported model (written to ``<model_name>.onnx``
            in the current working directory)
        dummy_input: the dummy input to the model
        kwargs: additional arguments to be passed to torch.onnx.export

    Returns:
    -------
        the path to the exported model
    """
    target_path = f"{model_name}.onnx"
    torch.onnx.export(
        model,
        dummy_input,
        target_path,
        input_names=["input"],
        output_names=["logits"],
        # Only the batch dimension is dynamic; spatial dims stay fixed to the dummy input
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        export_params=True,
        verbose=False,
        **kwargs,
    )
    logging.info(f"Model exported to {target_path}")
    return target_path