# FoldMark/protenix/utils/torch_utils.py

# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import nullcontext
from typing import Optional, Sequence, Union

import numpy as np
import torch
from torch import nn
from torch.nn.parameter import Parameter

def to_device(obj, device):
    """Move a tensor, or a (possibly nested) dict of tensors, to the given device."""
    if isinstance(obj, dict):
        for k, v in obj.items():
            if isinstance(v, dict):
                obj[k] = to_device(v, device)
            elif isinstance(v, torch.Tensor):
                obj[k] = v.to(device)
    elif isinstance(obj, torch.Tensor):
        obj = obj.to(device)
    else:
        raise TypeError(f"type {type(obj)} not supported")
    return obj
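
# Illustrative usage (added example; the key names below are hypothetical):
#   batch = {"coords": torch.zeros(4, 3), "meta": {"mask": torch.ones(4)}}
#   batch = to_device(batch, torch.device("cuda:0"))  # nested dicts are moved recursively
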
def cdist(a: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
    # For a tensor of shape [1, 512 * 14, 3], "donot_use_mm_for_euclid_dist" mode costs
    # 0.0489 s, while "use_mm_for_euclid_dist_if_necessary" costs 0.0419 s on CPU. On GPU
    # the two costs are negligible, so there is no need to sacrifice accuracy for speed here.
    return torch.cdist(
        a,
        b if b is not None else a,
        compute_mode="donot_use_mm_for_euclid_dist",
    )
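
# Illustrative usage (added example):
#   a = torch.randn(1, 512 * 14, 3)
#   d = cdist(a)  # pairwise self-distances, shape [1, 7168, 7168]
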
def map_values_to_list(data: dict, recursive: bool = True) -> dict:
    """
    Convert tensor/array values in a dictionary to plain Python lists.

    Args:
        data (dict): The dictionary whose values need to be converted.
        recursive (bool): Whether to recursively convert nested dictionaries. Defaults to True.

    Returns:
        dict: The dictionary with values converted to lists (modified in place).
    """
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            if v.dtype == torch.bfloat16:
                # numpy has no bfloat16, so upcast before converting
                v = v.float()
            data[k] = v.cpu().numpy().tolist()
        elif isinstance(v, np.ndarray):
            data[k] = v.tolist()
        elif isinstance(v, dict) and recursive:
            data[k] = map_values_to_list(v, recursive)
    return data
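
# Illustrative usage (added example; key names are hypothetical):
#   data = {"plddt": torch.rand(2), "meta": {"idx": np.array([1, 2])}}
#   data = map_values_to_list(data)  # all leaf values become plain Python lists
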
def round_values(data: dict, recursive: bool = True) -> dict:
    """
    Round the values in a dictionary to two decimal places.

    Args:
        data (dict): The dictionary whose values need to be rounded.
        recursive (bool): Whether to recursively round values in nested dictionaries. Defaults to True.

    Returns:
        dict: The dictionary with values rounded to two decimal places (modified in place).
    """
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            if v.dtype == torch.bfloat16:
                # numpy has no bfloat16, so upcast before converting
                v = v.float()
            data[k] = np.round(v.cpu().numpy(), 2)
        elif isinstance(v, np.ndarray):
            data[k] = np.round(v, 2)
        elif isinstance(v, list):
            # .tolist() yields native Python floats (and handles nested lists)
            data[k] = np.round(np.array(v), 2).tolist()
        elif isinstance(v, dict) and recursive:
            data[k] = round_values(v, recursive)
    return data
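
# Illustrative usage (added example; key name is hypothetical):
#   scores = {"lddt": np.array([0.1234, 0.5678])}
#   scores = round_values(scores)  # -> {"lddt": array([0.12, 0.57])}
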
def autocasting_disable_decorator(disable_casting: bool):
    """
    Decorator factory that optionally disables autocasting for a function.

    Args:
        disable_casting (bool): If True, disables autocasting and casts tensor
            arguments to float32; otherwise, keeps the ambient autocasting context.

    Returns:
        function: A decorator that wraps the function with the specified autocasting context.
    """

    def func_wrapper(func):
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            _amp_context = (
                torch.autocast(device_type="cuda", enabled=False)
                if disable_casting
                else nullcontext()
            )
            # .to(dtype=None) is a no-op, so tensors pass through unchanged when enabled
            dtype = torch.float32 if disable_casting else None
            with _amp_context:
                return func(
                    *(
                        v.to(dtype=dtype) if isinstance(v, torch.Tensor) else v
                        for v in args
                    ),
                    **{
                        k: v.to(dtype=dtype) if isinstance(v, torch.Tensor) else v
                        for k, v in kwargs.items()
                    },
                )

        return new_func

    return func_wrapper
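
# Illustrative usage (added example): force float32 math inside an autocast region.
#   @autocasting_disable_decorator(True)
#   def stable_norm(x):
#       return torch.linalg.norm(x, dim=-1)
#   # Inside torch.autocast("cuda"), stable_norm still receives float32 inputs
#   # and runs with autocast disabled.
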
def dict_to_tensor(feature_dict: dict) -> dict:
    """
    Convert values in a dictionary to tensors and ensure they have the correct dtype.

    Args:
        feature_dict (dict): The dictionary whose values need to be converted to tensors.

    Returns:
        dict: The dictionary with values converted to tensors and adjusted to the
            correct dtype (modified in place).
    """
    for k, v in feature_dict.items():
        if not isinstance(v, torch.Tensor):
            # np.asarray also gives plain lists a well-defined dtype
            dtype = np.asarray(v).dtype
            feature_dict[k] = torch.tensor(v)
            # Normalize integer features to int64 and floating features to float32.
            if dtype in [np.int64, np.int32]:
                feature_dict[k] = feature_dict[k].to(torch.int64)
            elif dtype in [np.float32, np.float64]:
                feature_dict[k] = feature_dict[k].to(torch.float32)
    return feature_dict
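
# Illustrative usage (added example; key names are hypothetical):
#   feats = {"token_index": np.arange(5), "coords": np.zeros((5, 3), dtype=np.float64)}
#   feats = dict_to_tensor(feats)  # -> int64 and float32 tensors, respectively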