|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from contextlib import nullcontext |
|
from typing import Sequence, Union |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from torch.nn.parameter import Parameter |
|
|
|
|
|
def to_device(obj, device): |
|
"""Move tensor or dict of tensors to device""" |
|
if isinstance(obj, dict): |
|
for k, v in obj.items(): |
|
if isinstance(v, dict): |
|
to_device(v, device) |
|
elif isinstance(v, torch.Tensor): |
|
obj[k] = obj[k].to(device) |
|
elif isinstance(obj, torch.Tensor): |
|
obj = obj.to(device) |
|
else: |
|
raise Exception(f"type {type(obj)} not supported") |
|
return obj |
|
|
|
|
|
def cdist(a: torch.Tensor, b: torch.Tensor = None): |
|
|
|
|
|
|
|
return torch.cdist( |
|
a, |
|
b if b is not None else a, |
|
compute_mode="donot_use_mm_for_euclid_dist", |
|
) |
|
|
|
|
|
def map_values_to_list(data: dict, recursive: bool = True) -> dict: |
|
""" |
|
Convert values in a dictionary to lists. |
|
|
|
Args: |
|
data (dict): The dictionary whose values need to be converted. |
|
recursive (bool): Whether to recursively convert nested dictionaries. Defaults to True. |
|
|
|
Returns: |
|
dict: The dictionary with values converted to lists. |
|
""" |
|
for k, v in data.items(): |
|
if isinstance(v, torch.Tensor): |
|
if v.dtype == torch.bfloat16: |
|
v = v.float() |
|
data[k] = v.cpu().numpy().tolist() |
|
elif isinstance(v, np.ndarray): |
|
data[k] = v.tolist() |
|
elif isinstance(v, dict) and recursive: |
|
data[k] = map_values_to_list(v, recursive) |
|
return data |
|
|
|
|
|
def round_values(data: dict, recursive: bool = True) -> dict: |
|
""" |
|
Round the values in a dictionary to two decimal places. |
|
|
|
Args: |
|
data (dict): The dictionary whose values need to be rounded. |
|
recursive (bool): Whether to recursively round values in nested dictionaries. Defaults to True. |
|
|
|
Returns: |
|
dict: The dictionary with values rounded to two decimal places. |
|
""" |
|
for k, v in data.items(): |
|
if isinstance(v, torch.Tensor): |
|
if v.dtype == torch.bfloat16: |
|
v = v.float() |
|
data[k] = np.round(v.cpu().numpy(), 2) |
|
elif isinstance(v, np.ndarray): |
|
data[k] = np.round(v, 2) |
|
elif isinstance(v, list): |
|
data[k] = list(np.round(np.array(v), 2)) |
|
elif isinstance(v, dict) and recursive: |
|
data[k] = round_values(v, recursive) |
|
return data |
|
|
|
|
|
def autocasting_disable_decorator(disable_casting: bool): |
|
""" |
|
Decorator to disable autocasting for a function. |
|
|
|
Args: |
|
disable_casting (bool): If True, disables autocasting; otherwise, uses the default autocasting context. |
|
|
|
Returns: |
|
function: A decorator that wraps the function with the specified autocasting context. |
|
""" |
|
|
|
def func_wrapper(func): |
|
def new_func(*args, **kwargs): |
|
_amp_context = ( |
|
torch.autocast(device_type="cuda", enabled=False) |
|
if disable_casting |
|
else nullcontext() |
|
) |
|
dtype = torch.float32 if disable_casting else None |
|
with _amp_context: |
|
return func( |
|
*( |
|
v.to(dtype=dtype) if isinstance(v, torch.Tensor) else v |
|
for v in args |
|
), |
|
**{ |
|
k: v.to(dtype=dtype) if isinstance(v, torch.Tensor) else v |
|
for k, v in kwargs.items() |
|
}, |
|
) |
|
|
|
return new_func |
|
|
|
return func_wrapper |
|
|
|
|
|
def dict_to_tensor(feature_dict: dict) -> dict: |
|
""" |
|
Convert values in a dictionary to tensors and ensure they have the correct dtype. |
|
|
|
Args: |
|
feature_dict (dict): The dictionary whose values need to be converted to tensors. |
|
|
|
Returns: |
|
dict: The dictionary with values converted to tensors and adjusted to the correct dtype. |
|
""" |
|
for k, v in feature_dict.items(): |
|
if not isinstance(v, torch.Tensor): |
|
dtype = feature_dict[k].dtype |
|
feature_dict[k] = torch.tensor(v) |
|
|
|
if dtype in [np.int64, np.int32]: |
|
feature_dict[k] = feature_dict[k].to(torch.int64) |
|
elif dtype in [np.float32, np.float64]: |
|
feature_dict[k] = feature_dict[k].to(torch.float32) |
|
|
|
return feature_dict |
|
|