Spaces:
Sleeping
Sleeping
File size: 1,379 Bytes
81ecb2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import numpy as np
from .typing import *
# torch / numpy math utils
def dot(x: Union[Tensor, ndarray], y: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
    """Dot product of ``x`` and ``y`` taken over the last dimension.

    The backend is chosen from the type of ``x`` alone: a numpy array is
    reduced with ``np.sum``, anything else is reduced with ``torch.sum``.

    Args:
        x (Union[Tensor, ndarray]): first operand, [..., C]
        y (Union[Tensor, ndarray]): second operand, [..., C]

    Returns:
        Union[Tensor, ndarray]: sum of the elementwise product over the last
            dim, with that dim kept at size 1, [..., 1]
    """
    product = x * y
    if not isinstance(x, np.ndarray):
        # Non-numpy input is handed to torch.
        return torch.sum(product, dim=-1, keepdim=True)
    return np.sum(product, axis=-1, keepdims=True)
def length(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Euclidean length of ``x`` along the last dimension.

    The sum of squares is clamped from below to ``eps`` before the square
    root is taken. The backend is chosen from the type of ``x``: numpy
    arrays use numpy, anything else uses torch.

    Args:
        x (Union[Tensor, ndarray]): input, [..., C]
        eps (float, optional): lower bound applied to the squared length.
            Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: length with the last dim kept at size 1,
            [..., 1]
    """
    if not isinstance(x, np.ndarray):
        # torch path: squared length comes from the module's own dot().
        squared = dot(x, x)
        return torch.sqrt(torch.clamp(squared, min=eps))
    squared = np.sum(x * x, axis=-1, keepdims=True)
    return np.sqrt(np.maximum(squared, eps))
def safe_normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Normalize ``x`` along the last dimension.

    Divides ``x`` by ``length(x, eps)``; ``eps`` is forwarded unchanged to
    ``length``.

    Args:
        x (Union[Tensor, ndarray]): input, [..., C]
        eps (float, optional): passed through to ``length``. Defaults to
            1e-20.

    Returns:
        Union[Tensor, ndarray]: ``x`` divided by its length, [..., C]
    """
    norm = length(x, eps)
    return x / norm