Spaces:
Running
Running
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.stack``
- ``constraints.square``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""
import torch
# Names exported by ``from <module> import *``; each one is defined further
# down in this module.
__all__ = [
    "Constraint",
    "boolean",
    "cat",
    "corr_cholesky",
    "dependent",
    "dependent_property",
    "greater_than",
    "greater_than_eq",
    "independent",
    "integer_interval",
    "interval",
    "half_open_interval",
    "is_dependent",
    "less_than",
    "lower_cholesky",
    "lower_triangular",
    "multinomial",
    "nonnegative",
    "nonnegative_integer",
    "one_hot",
    "positive",
    "positive_semidefinite",
    "positive_definite",
    "positive_integer",
    "real",
    "real_vector",
    "simplex",
    "square",
    "stack",
    "symmetric",
    "unit_interval",
]
class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        Raises:
            NotImplementedError: Always in this base class; subclasses are
                expected to override this method.
        """
        raise NotImplementedError

    def __repr__(self):
        # Private subclasses are named ``_Foo`` and should print as ``Foo()``.
        # Only strip a leading underscore: slicing unconditionally would
        # mangle names without one (``Constraint`` -> ``onstraint``).
        name = self.__class__.__name__
        if name.startswith("_"):
            name = name[1:]
        return name + "()"
class _Dependent(Constraint):
    """
    Placeholder for variables whose support depends on other variables.
    These variables obey no simple coordinate-wise constraints.

    Args:
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``NotImplemented`` is the sentinel for "not statically known".
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    # Both accessors must be properties: they are documented as attributes and
    # other constraints in this module read them as ``c.is_discrete`` /
    # ``c.event_dim`` without calling them.
    @property
    def is_discrete(self):
        """Statically known discreteness; raises NotImplementedError if unset."""
        if self._is_discrete is NotImplemented:
            raise NotImplementedError(".is_discrete cannot be determined statically")
        return self._is_discrete

    @property
    def event_dim(self):
        """Statically known event dimension; raises NotImplementedError if unset."""
        if self._event_dim is NotImplemented:
            raise NotImplementedError(".event_dim cannot be determined statically")
        return self._event_dim

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Support for syntax to customize static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Arguments left at ``NotImplemented`` inherit this instance's values.
        Returns a new ``_Dependent``; ``self`` is not modified.
        """
        if is_discrete is NotImplemented:
            is_discrete = self._is_discrete
        if event_dim is NotImplemented:
            event_dim = self._event_dim
        return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

    def check(self, x):
        """Always raises ValueError: a dependent constraint cannot be checked."""
        raise ValueError("Cannot determine validity of dependent constraint")
def is_dependent(constraint):
    """
    Check whether ``constraint`` is a ``_Dependent`` object.

    Args:
        constraint: The object to test.

    Returns:
        bool: True if ``constraint`` is an instance of ``_Dependent``
        (including subclasses), False otherwise.
    """
    return isinstance(constraint, _Dependent)
class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property to act like a `Dependent` constraint when
    called on a class and act like a property when called on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated.
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        # ``property`` precedes ``_Dependent`` in the bases, so this call
        # installs ``fn`` as the property getter.
        super().__init__(fn)
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...

        Returns a new ``_DependentProperty`` wrapping ``fn`` that carries over
        this instance's static attributes.
        """
        return _DependentProperty(
            fn, is_discrete=self._is_discrete, event_dim=self._event_dim
        )
class _IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.

    Args:
        base_constraint (Constraint): The constraint being wrapped.
        reinterpreted_batch_ndims (int): Non-negative number of rightmost
            dimensions of the base check result to fold into the event.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    # Properties (not plain methods) so they behave like the class-level
    # ``is_discrete`` / ``event_dim`` attributes of ``Constraint``.
    @property
    def is_discrete(self):
        """Discreteness of the wrapped constraint."""
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        """Base event dims plus the reinterpreted batch dims."""
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        """
        Check ``value`` with the base constraint, then require all entries in
        the trailing ``reinterpreted_batch_ndims`` dims to be valid.

        Raises:
            ValueError: If the base check result has fewer dimensions than
                ``reinterpreted_batch_ndims``.
        """
        result = self.base_constraint.check(value)
        if result.dim() < self.reinterpreted_batch_ndims:
            expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        # Collapse the reinterpreted dims into one trailing dim and reduce it.
        result = result.reshape(
            result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
        )
        result = result.all(-1)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        """Return elementwise whether ``value`` equals 0 or 1."""
        return (value == 0) | (value == 1)
class _OneHot(Constraint):
    """
    Constrain to one-hot vectors.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        """
        Return, per vector along the last dim, whether every entry is 0 or 1
        and the entries sum to exactly 1.
        """
        is_boolean = (value == 0) | (value == 1)
        is_normalized = value.sum(-1).eq(1)
        return is_boolean.all(-1) & is_normalized
class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.

    Args:
        lower_bound: Inclusive lower bound.
        upper_bound: Inclusive upper bound.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is integral and within bounds."""
        return (
            (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
        )

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string
class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.

    Args:
        upper_bound: Inclusive upper bound.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is integral and ``<= upper_bound``."""
        return (value % 1 == 0) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string
class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.

    Args:
        lower_bound: Inclusive lower bound.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is integral and ``>= lower_bound``."""
        return (value % 1 == 0) & (value >= self.lower_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        """Return elementwise whether ``value`` is not NaN."""
        return value == value  # False for NANs.
class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]`.

    Args:
        lower_bound: Exclusive lower bound.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is strictly above the bound."""
        return self.lower_bound < value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string
class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf)`.

    Args:
        lower_bound: Inclusive lower bound.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is at or above the bound."""
        return self.lower_bound <= value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string
class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound)`.

    Args:
        upper_bound: Exclusive upper bound.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` is strictly below the bound."""
        return value < self.upper_bound

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string
class _Interval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound]`.

    Args:
        lower_bound: Inclusive lower bound.
        upper_bound: Inclusive upper bound.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``value`` lies within both bounds."""
        return (self.lower_bound <= value) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string
class _HalfOpenInterval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound)`.

    Args:
        lower_bound: Inclusive lower bound.
        upper_bound: Exclusive upper bound.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Return elementwise whether ``lower_bound <= value < upper_bound``."""
        return (self.lower_bound <= value) & (value < self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string
class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1`.
    """

    event_dim = 1

    def check(self, value):
        """
        Return, per vector along the last dim, whether all entries are
        non-negative and the sum is within ``1e-6`` of 1.
        """
        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
class _Multinomial(Constraint):
    """
    Constrain to nonnegative integer values summing to at most an upper bound.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``.

    Args:
        upper_bound: Maximum allowed sum along the last dimension.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        # Chain to the base initializer like the other parameterised
        # constraints in this module.
        super().__init__()

    def check(self, x):
        """
        Return, per vector along the last dim, whether all entries are
        non-negative and their sum does not exceed ``upper_bound``.
        """
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.
    """

    event_dim = 2

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether every entry equals
        the corresponding entry of its lower-triangular part.
        """
        value_tril = value.tril()
        # Reduce the two matrix dims with ``all`` (as ``_Symmetric`` does)
        # instead of ``view(...).min(-1)[0]``: this does not depend on the
        # comparison result being viewable as a flat trailing dim, and it is
        # defined for empty matrices, where ``min`` over an empty dim raises.
        return (value_tril == value).all(-2).all(-1)
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether it equals its own
        lower-triangular part and every diagonal entry is strictly positive.
        """
        value_tril = value.tril()
        # ``all`` reductions replace ``view(...).min(-1)[0]``: no reliance on
        # the result being viewable, and no error on empty reductions.
        lower_triangular = (value_tril == value).all(-2).all(-1)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return lower_triangular & positive_diagonal
class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals and each
    row vector being of unit length.
    """

    event_dim = 2

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether it passes the
        ``_LowerCholesky`` check and every row norm is within tolerance of 1.
        """
        # Tolerance scales with the dtype's machine epsilon and the row length.
        tol = (
            torch.finfo(value.dtype).eps * value.size(-1) * 10
        )  # 10 is an adjustable fudge factor
        row_norm = torch.linalg.norm(value.detach(), dim=-1)
        unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
        return _LowerCholesky().check(value) & unit_row_norm
class _Square(Constraint):
    """
    Constrain to square matrices.
    """

    event_dim = 2

    def check(self, value):
        """
        Return a boolean tensor of shape ``value.shape[:-2]``, on
        ``value.device``, filled with whether the last two dims are equal.
        """
        # Squareness is a property of the shape, so every batch element gets
        # the same answer.
        return torch.full(
            size=value.shape[:-2],
            fill_value=(value.shape[-2] == value.shape[-1]),
            dtype=torch.bool,
            device=value.device,
        )
class _Symmetric(_Square):
    """
    Constrain to Symmetric square matrices.
    """

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether it is close to its
        transpose (``atol=1e-6``). If the squareness check fails anywhere,
        that check's result is returned instead.
        """
        square_check = super().check(value)
        if not square_check.all():
            # Not square: skip the transpose comparison.
            return square_check
        return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices.
    """

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether all eigenvalues from
        ``torch.linalg.eigvalsh`` are ``>= 0``. If the symmetry check fails
        anywhere, that check's result is returned instead.
        """
        sym_check = super().check(value)
        if not sym_check.all():
            # Skip the eigenvalue computation when symmetry already failed.
            return sym_check
        return torch.linalg.eigvalsh(value).ge(0).all(-1)
class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices.
    """

    def check(self, value):
        """
        Return, per matrix in the last two dims, whether
        ``torch.linalg.cholesky_ex`` reports ``info == 0`` for it. If the
        symmetry check fails anywhere, that check's result is returned
        instead.
        """
        sym_check = super().check(value)
        if not sym_check.all():
            # Skip the factorization when symmetry already failed.
            return sym_check
        return torch.linalg.cholesky_ex(value).info.eq(0)
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.

    Args:
        cseq: Iterable of ``Constraint`` objects, one per segment.
        dim (int): Dimension along which the segments are laid out.
        lengths: Optional segment sizes along ``dim``, one per constraint;
            defaults to 1 for every constraint.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    # Properties so these read as attributes, like on ``Constraint``.
    @property
    def is_discrete(self):
        """True if any component constraint is discrete."""
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        """Largest ``event_dim`` among the component constraints."""
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        """
        Check each consecutive segment of ``value`` along ``dim`` with its
        constraint and concatenate the results along ``dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.

    Args:
        cseq: Iterable of ``Constraint`` objects, one per slice.
        dim (int): Dimension along which the slices are stacked.
    """

    def __init__(self, cseq, dim=0):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.dim = dim
        super().__init__()

    # Properties so these read as attributes, like on ``Constraint``.
    @property
    def is_discrete(self):
        """True if any component constraint is discrete."""
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        """
        Largest component ``event_dim``, plus one when ``self.dim + dim`` is
        negative.
        """
        dim = max(c.event_dim for c in self.cseq)
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        """
        Check each slice of ``value`` along ``dim`` with the constraint at the
        same position and stack the results along ``dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        # ``zip`` stops at the shorter of the slices and the constraints.
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )
# Public interface.
# Names bound to an instance are ready-to-use constraints; names bound to a
# class are called with parameters to build one, e.g. ``interval(0.0, 1.0)``.
# Order matters: ``real_vector`` is built from ``independent`` and ``real``.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack