Spaces:
Running
on
L4
Running
on
L4
import torch | |
import numpy as np | |
from .typing import * | |
# torch / numpy math utils
def dot(x: Union[Tensor, ndarray], y: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
    """Dot product along the last dimension.

    The backend is selected from the type of ``x`` alone: a NumPy array is
    reduced with ``np.sum``, anything else is handed to ``torch.sum``.

    Args:
        x (Union[Tensor, ndarray]): x, [..., C]
        y (Union[Tensor, ndarray]): y, [..., C]

    Returns:
        Union[Tensor, ndarray]: x dot y, [..., 1]
    """
    if isinstance(x, np.ndarray):
        # keepdims retains the reduced axis so the result is [..., 1].
        return np.sum(x * y, axis=-1, keepdims=True)
    return torch.sum(x * y, dim=-1, keepdim=True)
def length(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Euclidean length of an array along the last dimension.

    The squared length is floored at ``eps`` before the square root is
    taken, so the smallest value this function returns is ``sqrt(eps)``.

    Args:
        x (Union[Tensor, ndarray]): x, [..., C]
        eps (float, optional): lower bound applied to the squared length.
            Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: length, [..., 1]
    """
    # NOTE(review): the floor only helps if eps is representable in x's
    # dtype; 1e-20 may underflow to 0 for float16 inputs -- confirm callers.
    if isinstance(x, np.ndarray):
        squared = np.sum(x * x, axis=-1, keepdims=True)
        return np.sqrt(np.maximum(squared, eps))
    squared = torch.sum(x * x, dim=-1, keepdim=True)
    return torch.sqrt(torch.clamp(squared, min=eps))
def safe_normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Normalize an array along the last dimension.

    Divides ``x`` by ``length(x, eps)``. ``length`` floors the squared norm
    at ``eps`` before taking the square root, which is what keeps an
    all-zero input from being divided by zero.

    Args:
        x (Union[Tensor, ndarray]): x, [..., C]
        eps (float, optional): lower bound forwarded to ``length``.
            Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: normalized x, [..., C]
    """
    # length() returns [..., 1], which broadcasts against x's [..., C].
    return x / length(x, eps)