Spaces:
Runtime error
Runtime error
File size: 5,494 Bytes
153628e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import logging
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import nn
from doctr.utils.data import download_from_url
# Public API of this module; leading-underscore helpers are exported deliberately
# for reuse by sibling doctr modules.
__all__ = [
    "load_pretrained_params",
    "conv_sequence_pt",
    "set_device_and_dtype",
    "export_model_to_onnx",
    "_copy_tensor",
    "_bf16_to_float32",
]
def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
return x.clone().detach()
def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x
def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
    ----
        model: the PyTorch model to be loaded
        url: URL of the zipped set of parameters
        hash_prefix: first characters of SHA256 expected hash
        ignore_keys: list of weights to be ignored from the state_dict
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`

    Raises:
    ------
        ValueError: if `ignore_keys` is given and the resulting key mismatch is not
            exactly the set of deliberately dropped keys
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
        return

    archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)

    # weights_only=True restricts unpickling to tensors & primitive containers, which
    # guards against arbitrary code execution from a tampered checkpoint archive
    state_dict = torch.load(archive_path, map_location="cpu", weights_only=True)

    if ignore_keys is not None and len(ignore_keys) > 0:
        # Drop the weights the caller explicitly asked to skip (e.g. a classification
        # head whose shape differs from the checkpoint's)
        for key in ignore_keys:
            state_dict.pop(key)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        # The only tolerated mismatch is exactly the set of keys we removed on purpose
        if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
            raise ValueError("unable to load state_dict, due to non-matching keys.")
    else:
        # Strict load: any mismatch raises inside load_state_dict
        model.load_state_dict(state_dict)
def conv_sequence_pt(
    in_channels: int,
    out_channels: int,
    relu: bool = False,
    bn: bool = False,
    **kwargs: Any,
) -> List[nn.Module]:
    """Builds a convolutional-based layer sequence

    >>> from torch.nn import Sequential
    >>> from doctr.models import conv_sequence
    >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))

    Args:
    ----
        in_channels: number of input channels
        out_channels: number of output channels
        relu: whether ReLU should be used
        bn: should a batch normalization layer be added
        **kwargs: additional arguments to be passed to the convolutional layer

    Returns:
    -------
        list of layers
    """
    # BatchNorm carries its own learnable shift, so the conv bias is redundant when
    # BN follows — default it off in that case unless the caller overrides it
    kwargs.setdefault("bias", not bn)

    layers: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU(inplace=True))
    return layers
def set_device_and_dtype(
    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
) -> Tuple[Any, List[torch.Tensor]]:
    """Move a model and its input batches onto a given device / dtype

    >>> import torch
    >>> from torch import nn
    >>> from doctr.models.utils import set_device_and_dtype
    >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    >>> batches = [torch.rand(8) for _ in range(2)]
    >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)

    Args:
    ----
        model: the model to be set
        batches: the batches to be set
        device: the device to be used
        dtype: the dtype to be used

    Returns:
    -------
        the model and batches set
    """
    casted_batches = [batch.to(device=device, dtype=dtype) for batch in batches]
    casted_model = model.to(device=device, dtype=dtype)
    return casted_model, casted_batches
def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str:
    """Export model to ONNX format.

    >>> import torch
    >>> from doctr.models.classification import resnet18
    >>> from doctr.models.utils import export_model_to_onnx
    >>> model = resnet18(pretrained=True)
    >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

    Args:
    ----
        model: the PyTorch model to be exported
        model_name: the name for the exported model
        dummy_input: the dummy input to the model
        kwargs: additional arguments to be passed to torch.onnx.export

    Returns:
    -------
        the path to the exported model
    """
    onnx_path = f"{model_name}.onnx"
    # Batch axis is kept dynamic so the exported graph accepts any batch size
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        export_params=True,
        verbose=False,
        **kwargs,
    )
    logging.info(f"Model exported to {model_name}.onnx")
    return onnx_path
|