Spaces:
Running
Running
File size: 5,375 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
import re
import warnings
from typing import Tuple, List, Optional, Union, Dict, Any
# Compiled regex for a Semantic Versioning (https://semver.org/) string:
# MAJOR.MINOR.PATCH, optionally followed by "-<prerelease>" and/or "+<build>".
# Named groups: major, minor, patch, prerelease, build.
# Numeric identifiers reject leading zeros ("0" or "[1-9]\d*"), and the ^...$
# anchors require the whole string to match (no surrounding whitespace).
# Compiled with re.VERBOSE, so whitespace and line breaks inside the pattern
# text are ignored.
SEMVER_VERSION_PATTERN = re.compile(
r"""
^
(?P<major>0|[1-9]\d*)
\.
(?P<minor>0|[1-9]\d*)
\.
(?P<patch>0|[1-9]\d*)
(?:-(?P<prerelease>
(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
))?
(?:\+(?P<build>
[0-9a-zA-Z-]+
(?:\.[0-9a-zA-Z-]+)*
))?
$
""",
re.VERBOSE,
)
# Uncompiled PEP 440 (https://peps.python.org/pep-0440/) version pattern:
#   [v][<epoch>!]<release>[<pre>][<post>][<dev>][+<local>]
# Named groups: epoch, release, pre/pre_l/pre_n, post/post_n1/post_l/post_n2,
# dev/dev_l/dev_n, local.
# Usage requirements for callers:
#   - The string has no ^/$ anchors, so callers must add them for a full match.
#   - It contains inline "#" comments and layout whitespace, so it must be
#     compiled with re.VERBOSE.
#   - Its letters are all lowercase, so re.IGNORECASE is needed to accept
#     upper-case spellings such as "1.0RC1".
PEP_440_VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
def _validate_input(
    tensors: List[torch.Tensor],
    dim_range: Tuple[int, int] = (0, -1),
    data_range: Tuple[float, float] = (0., -1.),
    size_range: Optional[Tuple[int, int]] = None,
) -> None:
    r"""Assert that all ``tensors`` agree with the first one and satisfy the given limits.

    Every check is a plain ``assert``. The function returns immediately when Python
    runs with ``-O`` (``__debug__`` is False), so no validation happens then.

    Args:
        tensors: Tensors to check. ``tensors[0]`` is the reference for device and size.
        dim_range: Allowed number of dimensions as (min, max), inclusive. Equal bounds
            require exactly that many dimensions; min > max (the default) skips the check.
        data_range: Allowed range of values as (min, max), inclusive. min >= max
            (the default) skips the check.
        size_range: Dimensions to include in size comparison, as (start_dim, end_dim + 1),
            i.e. a slice of ``size()``. ``None`` compares the full sizes.
    """
    if not __debug__:
        return

    reference = tensors[0]
    for tensor in tensors:
        assert torch.is_tensor(tensor), f'Expected torch.Tensor, got {type(tensor)}'
        assert tensor.device == reference.device, \
            f'Expected tensors to be on {reference.device}, got {tensor.device}'

        # Compare either only the requested slice of the shape, or the whole shape.
        if size_range is not None:
            start, stop = size_range[0], size_range[1]
            assert tensor.size()[start:stop] == reference.size()[start:stop], \
                f'Expected tensors with same size at given dimensions, got {tensor.size()} and {reference.size()}'
        else:
            assert tensor.size() == reference.size(), \
                f'Expected tensors with same size, got {tensor.size()} and {reference.size()}'

        min_dim, max_dim = dim_range[0], dim_range[1]
        n_dims = tensor.dim()
        if min_dim == max_dim:
            assert n_dims == min_dim, f'Expected number of dimensions to be {min_dim}, got {n_dims}'
        elif min_dim < max_dim:
            assert min_dim <= n_dims <= max_dim, \
                f'Expected number of dimensions to be between {min_dim} and {max_dim}, got {n_dims}'

        min_val, max_val = data_range[0], data_range[1]
        if min_val < max_val:
            assert min_val <= tensor.min(), \
                f'Expected values to be greater or equal to {min_val}, got {tensor.min()}'
            assert tensor.max() <= max_val, \
                f'Expected values to be lower or equal to {max_val}, got {tensor.max()}'
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    r"""Reduce ``x`` over its first (batch) dimension according to ``reduction``.

    Args:
        x: Tensor with shape (N, *).
        reduction: ``'none'`` returns ``x`` unchanged, ``'mean'`` averages over dim 0,
            ``'sum'`` sums over dim 0. Default: ``'mean'``

    Returns:
        ``x`` itself for ``'none'``; otherwise the result of reducing dim 0.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.
    """
    if reduction == 'mean':
        return x.mean(dim=0)
    if reduction == 'sum':
        return x.sum(dim=0)
    if reduction == 'none':
        return x
    raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
    """Extract the numeric release part of a SemVer or PEP 440 version string.

    For more on Semver check: https://semver.org/
    For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.

    Implementation is inspired by:
        - https://github.com/python-semver
        - https://github.com/pypa/packaging

    The string is first matched against ``SEMVER_VERSION_PATTERN``. If that fails,
    it is matched against ``PEP_440_VERSION_PATTERN``, anchored, case-insensitive,
    with surrounding whitespace allowed. Pre-release, post-release, dev, build
    and local parts are accepted but discarded.

    Args:
        version: Version string, or UTF-8 encoded bytes.

    Returns:
        ``(major, minor, patch)`` for a SemVer string. For a PEP 440 string, every
        component of the release segment (e.g. ``(2, 0)`` for ``"2.0rc1"``).
        An empty tuple, after emitting a warning, if neither format matches.

    Raises:
        TypeError: If ``version`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(version, bytes):
        version = version.decode("UTF-8")
    elif not isinstance(version, str):
        raise TypeError(f"not expecting type {type(version)}")

    # Semver processing: keep only the numeric core.
    match = SEMVER_VERSION_PATTERN.match(version)
    if match:
        return tuple(int(match.group(key)) for key in ('major', 'minor', 'patch'))

    # PEP 440 processing. The shared pattern is unanchored and written in
    # lowercase VERBOSE form, so add anchors and the matching flags here.
    regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
    match = regex.search(version)
    if match is None:
        # stacklevel=2 attributes the warning to the caller instead of this helper.
        warnings.warn(f"{version} is not a valid SemVer or PEP 440 string", stacklevel=2)
        return ()
    return tuple(int(part) for part in match.group("release").split("."))
|