import json
import os
import random
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Optional, Union

import numpy as np
import pandas as pd
import torch
from biotite.structure.atoms import AtomArray
from ml_collections.config_dict import ConfigDict
from torch.utils.data import Dataset

from protenix.data.constants import EvaluationChainInterface
from protenix.data.data_pipeline import DataPipeline
from protenix.data.featurizer import Featurizer
from protenix.data.msa_featurizer import MSAFeaturizer
from protenix.data.tokenizer import TokenArray
from protenix.data.utils import data_type_transform, make_dummy_feature
from protenix.utils.cropping import CropData
from protenix.utils.file_io import read_indices_csv
from protenix.utils.logger import get_logger
from protenix.utils.torch_utils import dict_to_tensor

logger = get_logger(__name__)


class BaseSingleDataset(Dataset):
    """
    Dataset for a single data source.

    data = self.__getitem__(idx) returns a dict of features and labels;
    the keys and their shapes are defined in protenix.data.utils.
    """

    def __init__(
        self,
        mmcif_dir: Union[str, Path],
        bioassembly_dict_dir: Optional[Union[str, Path]],
        indices_fpath: Union[str, Path],
        cropping_configs: dict[str, Any],
        msa_featurizer: Optional[MSAFeaturizer] = None,
        template_featurizer: Optional[Any] = None,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        super(BaseSingleDataset, self).__init__()

        self.mmcif_dir = mmcif_dir
        self.bioassembly_dict_dir = bioassembly_dict_dir
        self.indices_fpath = indices_fpath
        self.cropping_configs = cropping_configs
        self.name = name

        # Feature / augmentation switches.
        self.ref_pos_augment = kwargs.get("ref_pos_augment", True)
        self.lig_atom_rename = kwargs.get("lig_atom_rename", False)
        self.reassign_continuous_chain_ids = kwargs.get(
            "reassign_continuous_chain_ids", False
        )
        self.shuffle_mols = kwargs.get("shuffle_mols", False)
        self.shuffle_sym_ids = kwargs.get("shuffle_sym_ids", False)

        # Evaluation-related switches.
        self.find_pocket = kwargs.get("find_pocket", False)
        self.find_all_pockets = kwargs.get("find_all_pockets", False)
        self.find_eval_chain_interface = kwargs.get("find_eval_chain_interface", False)
        self.group_by_pdb_id = kwargs.get("group_by_pdb_id", False)
        self.sort_by_n_token = kwargs.get("sort_by_n_token", False)

        self.random_sample_if_failed = kwargs.get("random_sample_if_failed", False)
        self.use_reference_chains_only = kwargs.get("use_reference_chains_only", False)
        self.is_distillation = kwargs.get("is_distillation", False)

        # Dataset filtering options.
        self.max_n_token = kwargs.get("max_n_token", -1)
        self.pdb_list = kwargs.get("pdb_list", None)
        if self.pdb_list is not None and len(self.pdb_list) == 0:
            self.pdb_list = None

        self.exclusion_dict = kwargs.get("exclusion", {})
        self.limits = kwargs.get("limits", -1)

        self.error_dir = kwargs.get("error_dir", None)
        if self.error_dir is not None:
            os.makedirs(self.error_dir, exist_ok=True)

        self.msa_featurizer = msa_featurizer
        self.template_featurizer = template_featurizer

        self.indices_list = self.read_indices_list(indices_fpath)

    @staticmethod
    def read_pdb_list(pdb_list: Union[list, str]) -> Optional[list]:
        """
        Reads a list of PDB IDs from a file or directly from a list.

        Args:
            pdb_list: A list of PDB IDs or a file path containing PDB IDs.

        Returns:
            A list of PDB IDs if the input is valid, otherwise None.
        """
        if pdb_list is None:
            return None

        if isinstance(pdb_list, list):
            return pdb_list

        with open(pdb_list, "r") as f:
            pdb_filter_list = []
            for line in f.readlines():
                line = line.strip()
                if line:
                    pdb_filter_list.append(line)
        return pdb_filter_list

    def read_indices_list(self, indices_fpath: Union[str, Path]) -> pd.DataFrame:
        """
        Reads and processes a list of indices from a CSV file.

        Args:
            indices_fpath: Path to the CSV file containing the indices.

        Returns:
            A DataFrame containing the processed indices.
        """
        indices_list = read_indices_csv(indices_fpath)
        num_data = len(indices_list)
        logger.info(f"#Rows in indices list: {num_data}")

        # Keep only the PDB IDs listed in pdb_list, if one is given.
        if self.pdb_list is not None:
            pdb_filter_list = set(self.read_pdb_list(pdb_list=self.pdb_list))
            indices_list = indices_list[indices_list["pdb_id"].isin(pdb_filter_list)]
            logger.info(f"[filtered by pdb_list] #Rows: {len(indices_list)}")

        # Drop samples whose token count exceeds max_n_token.
        if self.max_n_token > 0:
            valid_mask = indices_list["num_tokens"].astype(int) <= self.max_n_token
            removed_list = indices_list[~valid_mask]
            indices_list = indices_list[valid_mask]
            logger.info(f"[removed] #Rows: {len(removed_list)}")
            logger.info(f"[removed] #PDB: {removed_list['pdb_id'].nunique()}")
            logger.info(
                f"[filtered by n_token ({self.max_n_token})] #Rows: {len(indices_list)}"
            )

        # Drop rows whose (possibly composite) key appears in an exclusion list.
        for col_name, exclusion_list in self.exclusion_dict.items():
            cols = col_name.split("|")
            exclusion_set = {tuple(excl.split("|")) for excl in exclusion_list}

            def is_valid(row):
                return tuple(row[col] for col in cols) not in exclusion_set

            valid_mask = indices_list.apply(is_valid, axis=1)
            indices_list = indices_list[valid_mask].reset_index(drop=True)
            logger.info(
                f"[Excluded by {col_name} -- {exclusion_list}] #Rows: {len(indices_list)}"
            )
        self.print_data_stats(indices_list)

        if self.group_by_pdb_id:
            indices_list = [
                df.reset_index() for _, df in indices_list.groupby("pdb_id", sort=True)
            ]

        if self.sort_by_n_token:
            if self.group_by_pdb_id:
                indices_list = sorted(
                    indices_list,
                    key=lambda df: int(df["num_tokens"].iloc[0]),
                    reverse=True,
                )
            else:
                indices_list = indices_list.sort_values(
                    by="num_tokens", key=lambda x: x.astype(int), ascending=False
                ).reset_index(drop=True)

        if self.find_eval_chain_interface:
            # Keep only entries whose eval_type belongs to the evaluation chain/interface types.
            if self.group_by_pdb_id:
                indices_list = [
                    df
                    for df in indices_list
                    if len(
                        set(df["eval_type"].to_list()).intersection(
                            set(EvaluationChainInterface)
                        )
                    )
                    > 0
                ]
            else:
                indices_list = indices_list[
                    indices_list["eval_type"].apply(
                        lambda x: x in EvaluationChainInterface
                    )
                ]
        if self.limits > 0 and len(indices_list) > self.limits:
            logger.info(
                f"Limit indices list size from {len(indices_list)} to {self.limits}"
            )
            indices_list = indices_list[: self.limits]
        return indices_list

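    # Columns the indices CSV is expected to provide (inferred from the accesses in this
    # module; illustrative, not necessarily exhaustive): pdb_id, num_tokens,
    # type ("chain" or "interface"), eval_type, cluster_id, chain_1_id, chain_2_id,
    # mol_1_type, mol_2_type, entity_1_id, entity_2_id, assembly_id, and optionally weights.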
    def print_data_stats(self, df: pd.DataFrame) -> None:
        """
        Prints statistics about the dataset, including the distribution of molecular group types.

        Args:
            df: A DataFrame containing the indices list.
        """
        if self.name:
            logger.info("-" * 10 + f" Dataset {self.name} " + "-" * 10)
        df["mol_group_type"] = df.apply(
            lambda row: "_".join(
                sorted(
                    [
                        str(row["mol_1_type"]),
                        str(row["mol_2_type"]).replace("nan", "intra"),
                    ]
                )
            ),
            axis=1,
        )

        group_size_dict = dict(df["mol_group_type"].value_counts())
        for i, n_i in group_size_dict.items():
            logger.info(f"{i}: {n_i}/{len(df)}({round(n_i*100/len(df), 2)}%)")

        logger.info("-" * 30)
        if "cluster_id" in df.columns:
            n_cluster = df["cluster_id"].nunique()
            for i in group_size_dict:
                n_i = df[df["mol_group_type"] == i]["cluster_id"].nunique()
                logger.info(f"{i}: {n_i}/{n_cluster}({round(n_i*100/n_cluster, 2)}%)")
            logger.info("-" * 30)

        logger.info(f"Final pdb ids: {len(set(df.pdb_id.tolist()))}")
        logger.info("-" * 30)

    def __len__(self) -> int:
        return len(self.indices_list)

    def save_error_data(self, idx: int, error_message: str) -> None:
        """
        Saves the error data for a specific index to a JSON file in the error directory.

        Args:
            idx: The index of the data sample that caused the error.
            error_message: The error message to be saved.
        """
        if self.error_dir is not None:
            sample_indice = self._get_sample_indice(idx=idx)
            data = sample_indice.to_dict()
            data["error"] = error_message

            filename = f"{sample_indice.pdb_id}-{sample_indice.chain_1_id}-{sample_indice.chain_2_id}.json"
            fpath = os.path.join(self.error_dir, filename)
            if not os.path.exists(fpath):
                with open(fpath, "w") as f:
                    json.dump(data, f)

    def __getitem__(self, idx: int):
        """
        Retrieves a data sample by processing the given index.
        If an error occurs, it either saves the error data and re-raises, or randomly
        samples another index and retries (up to 10 attempts).

        Args:
            idx: The index of the data sample to retrieve.

        Returns:
            A dictionary containing the processed data sample.
        """
        for _ in range(10):
            try:
                data = self.process_one(idx)
                return data
            except Exception as e:
                error_message = f"{e} at idx {idx}:\n{traceback.format_exc()}"
                self.save_error_data(idx, error_message)

                if self.random_sample_if_failed:
                    logger.exception(f"[skip data {idx}] {error_message}")
                    # Retry with a randomly sampled index.
                    idx = random.choice(range(len(self.indices_list)))
                    continue
                else:
                    raise Exception(e)
        # All retries failed; fail loudly instead of returning an unbound variable.
        raise RuntimeError(f"Failed to fetch a valid sample after 10 attempts (last idx: {idx}).")

    def _get_bioassembly_data(
        self, idx: int
    ) -> tuple[pd.Series, dict[str, Any], Optional[str]]:
        sample_indice = self._get_sample_indice(idx=idx)
        if self.bioassembly_dict_dir is not None:
            bioassembly_dict_fpath = os.path.join(
                self.bioassembly_dict_dir, sample_indice.pdb_id + ".pkl.gz"
            )
        else:
            bioassembly_dict_fpath = None

        bioassembly_dict = DataPipeline.get_data_bioassembly(
            bioassembly_dict_fpath=bioassembly_dict_fpath
        )
        bioassembly_dict["pdb_id"] = sample_indice.pdb_id
        return sample_indice, bioassembly_dict, bioassembly_dict_fpath

    @staticmethod
    def _reassign_atom_array_chain_id(atom_array: AtomArray):
        """
        In experiments conducted to observe overfitting effects using training sets,
        the pre-stored AtomArray may end up with discontinuous chain IDs due to filtering.
        This temporary patch remaps the IDs to a contiguous range.

        e.g. 3x6u asym_id_int = [0, 1, 2, ..., 18, 20] -> reassigned asym_id_int = [0, 1, 2, ..., 18, 19]
        """

        def _get_contiguous_array(array):
            array_uniq = np.sort(np.unique(array))
            map_dict = {i: idx for idx, i in enumerate(array_uniq)}
            new_array = np.vectorize(map_dict.get)(array)
            return new_array

        atom_array.asym_id_int = _get_contiguous_array(atom_array.asym_id_int)
        atom_array.entity_id_int = _get_contiguous_array(atom_array.entity_id_int)
        atom_array.sym_id_int = _get_contiguous_array(atom_array.sym_id_int)
        return atom_array

    @staticmethod
    def _shuffle_array_based_on_mol_id(token_array: TokenArray, atom_array: AtomArray):
        """
        Shuffle both token_array and atom_array.
        Atoms/tokens sharing the same mol_id are shuffled as an integrated unit.
        """
        # Map each token to the mol_id of its centre atom.
        centre_atom_indices = token_array.get_annotation("centre_atom_index")
        token_mol_id = atom_array[centre_atom_indices].mol_id

        # Shuffle the molecule order.
        shuffled_mol_ids = np.unique(token_mol_id).copy()
        np.random.shuffle(shuffled_mol_ids)

        # Re-order token indices molecule by molecule, keeping tokens of the same
        # molecule contiguous.
        original_token_indices = np.arange(len(token_mol_id))
        shuffled_token_indices = []
        for mol_id in shuffled_mol_ids:
            mol_token_indices = original_token_indices[token_mol_id == mol_id]
            shuffled_token_indices.append(mol_token_indices)
        shuffled_token_indices = np.concatenate(shuffled_token_indices)

        token_array, atom_array, _, _ = CropData.select_by_token_indices(
            token_array=token_array,
            atom_array=atom_array,
            selected_token_indices=shuffled_token_indices,
        )

        return token_array, atom_array
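    # Illustrative example (assumed toy values): with token_mol_id = [0, 0, 1, 1, 2]
    # and a shuffled molecule order of [2, 0, 1], the tokens are re-ordered as
    # [4, 0, 1, 2, 3]; tokens of each molecule stay contiguous.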
    @staticmethod
    def _assign_random_sym_id(atom_array: AtomArray):
        """
        Assign a random sym_id to chains of the same entity_id.
        e.g.
            when entity_id = 0:
                sym_id_int = [0, 1, 2] -> random_sym_id_int = [2, 0, 1]
            when entity_id = 1:
                sym_id_int = [0, 1, 2, 3] -> random_sym_id_int = [3, 0, 1, 2]
        """

        def _shuffle(x):
            x_unique = np.sort(np.unique(x))
            x_shuffled = x_unique.copy()
            np.random.shuffle(x_shuffled)
            map_dict = dict(zip(x_unique, x_shuffled))
            new_x = np.vectorize(map_dict.get)(x)
            return new_x.copy()

        for entity_id in np.unique(atom_array.label_entity_id):
            mask = atom_array.label_entity_id == entity_id
            atom_array.sym_id_int[mask] = _shuffle(atom_array.sym_id_int[mask])
        return atom_array

    def process_one(
        self, idx: int, return_atom_token_array: bool = False
    ) -> dict[str, dict]:
        """
        Processes a single data sample: retrieves the bioassembly data, applies the
        configured transformations, crops the data, and extracts features and labels.

        Args:
            idx: The index of the data sample to process.
            return_atom_token_array: Whether to also return the cropped atom and token arrays.

        Returns:
            A dict containing the input features, labels, basic_info and, optionally,
            the cropped atom and token arrays.
        """
        sample_indice, bioassembly_dict, bioassembly_dict_fpath = (
            self._get_bioassembly_data(idx=idx)
        )

        if self.use_reference_chains_only:
            # Restrict the bioassembly to the chain(s) referenced by this sample.
            ref_chain_ids = [sample_indice.chain_1_id, sample_indice.chain_2_id]
            if sample_indice.type == "chain":
                ref_chain_ids.pop(-1)

            token_centre_atom_indices = bioassembly_dict["token_array"].get_annotation(
                "centre_atom_index"
            )
            token_chain_id = bioassembly_dict["atom_array"][
                token_centre_atom_indices
            ].chain_id
            is_ref_chain = np.isin(token_chain_id, ref_chain_ids)
            bioassembly_dict["token_array"], bioassembly_dict["atom_array"], _, _ = (
                CropData.select_by_token_indices(
                    token_array=bioassembly_dict["token_array"],
                    atom_array=bioassembly_dict["atom_array"],
                    selected_token_indices=np.arange(len(is_ref_chain))[is_ref_chain],
                )
            )

        if self.shuffle_mols:
            bioassembly_dict["token_array"], bioassembly_dict["atom_array"] = (
                self._shuffle_array_based_on_mol_id(
                    token_array=bioassembly_dict["token_array"],
                    atom_array=bioassembly_dict["atom_array"],
                )
            )

        if self.shuffle_sym_ids:
            bioassembly_dict["atom_array"] = self._assign_random_sym_id(
                bioassembly_dict["atom_array"]
            )

        if self.reassign_continuous_chain_ids:
            bioassembly_dict["atom_array"] = self._reassign_atom_array_chain_id(
                bioassembly_dict["atom_array"]
            )

        # Crop the bioassembly according to the configured crop size and method weights.
        (
            crop_method,
            cropped_token_array,
            cropped_atom_array,
            cropped_msa_features,
            cropped_template_features,
            reference_token_index,
        ) = self.crop(
            sample_indice=sample_indice,
            bioassembly_dict=bioassembly_dict,
            **self.cropping_configs,
        )

        feat, label, label_full = self.get_feature_and_label(
            idx=idx,
            token_array=cropped_token_array,
            atom_array=cropped_atom_array,
            msa_features=cropped_msa_features,
            template_features=cropped_template_features,
            full_atom_array=bioassembly_dict["atom_array"],
            is_spatial_crop="spatial" in crop_method.lower(),
        )

        # Collect basic statistics about the sample.
        basic_info = {
            "pdb_id": (
                bioassembly_dict["pdb_id"]
                if self.is_distillation is False
                else sample_indice["pdb_id"]
            ),
            "N_asym": torch.tensor([len(torch.unique(feat["asym_id"]))]),
            "N_token": torch.tensor([feat["token_index"].shape[0]]),
            "N_atom": torch.tensor([feat["atom_to_token_idx"].shape[0]]),
            "N_msa": torch.tensor([feat["msa"].shape[0]]),
            "bioassembly_dict_fpath": bioassembly_dict_fpath,
            "N_msa_prot_pair": torch.tensor([feat["prot_pair_num_alignments"]]),
            "N_msa_prot_unpair": torch.tensor([feat["prot_unpair_num_alignments"]]),
            "N_msa_rna_pair": torch.tensor([feat["rna_pair_num_alignments"]]),
            "N_msa_rna_unpair": torch.tensor([feat["rna_unpair_num_alignments"]]),
        }

        for mol_type in ("protein", "ligand", "rna", "dna"):
            abbr = {"protein": "prot", "ligand": "lig"}
            abbr_type = abbr.get(mol_type, mol_type)
            mol_type_mask = feat[f"is_{mol_type}"].bool()
            n_atom = int(mol_type_mask.sum(dim=-1).item())
            n_token = len(torch.unique(feat["atom_to_token_idx"][mol_type_mask]))
            basic_info[f"N_{abbr_type}_atom"] = torch.tensor([n_atom])
            basic_info[f"N_{abbr_type}_token"] = torch.tensor([n_token])

        # Map asym_id_int -> chain_id for the cropped structure.
        asym_id_to_chain_id = {
            atom.asym_id_int: atom.chain_id for atom in cropped_atom_array
        }
        chain_id_list = [
            asym_id_to_chain_id[asym_id_int]
            for asym_id_int in sorted(asym_id_to_chain_id.keys())
        ]
        basic_info["chain_id"] = chain_id_list

        data = {
            "input_feature_dict": feat,
            "label_dict": label,
            "label_full_dict": label_full,
            "basic": basic_info,
        }

        if return_atom_token_array:
            data["cropped_atom_array"] = cropped_atom_array
            data["cropped_token_array"] = cropped_token_array
        return data

    def crop(
        self,
        sample_indice: pd.Series,
        bioassembly_dict: dict[str, Any],
        crop_size: int,
        method_weights: list[float],
        contiguous_crop_complete_lig: bool = True,
        spatial_crop_complete_lig: bool = True,
        drop_last: bool = True,
        remove_metal: bool = True,
    ) -> tuple[str, TokenArray, AtomArray, dict[str, Any], dict[str, Any], Any]:
        """
        Crops the bioassembly data based on the specified configurations.

        Returns:
            A tuple containing the cropping method, cropped token array, cropped atom array,
            cropped MSA features, cropped template features, and the reference token index.
        """
        return DataPipeline.crop(
            one_sample=sample_indice,
            bioassembly_dict=bioassembly_dict,
            crop_size=crop_size,
            msa_featurizer=self.msa_featurizer,
            template_featurizer=self.template_featurizer,
            method_weights=method_weights,
            contiguous_crop_complete_lig=contiguous_crop_complete_lig,
            spatial_crop_complete_lig=spatial_crop_complete_lig,
            drop_last=drop_last,
            remove_metal=remove_metal,
        )
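    # Illustrative cropping_configs (assumed values; the keys mirror the parameters of
    # self.crop above):
    #   cropping_configs = {
    #       "crop_size": 384,
    #       "method_weights": [0.2, 0.4, 0.4],
    #       "contiguous_crop_complete_lig": False,
    #       "spatial_crop_complete_lig": False,
    #       "drop_last": False,
    #       "remove_metal": True,
    #   }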
    def _get_sample_indice(self, idx: int) -> pd.Series:
        """
        Retrieves the sample indice for a given index. If the dataset is grouped by PDB ID,
        it returns the first row of the group at that index; otherwise, it returns the row
        at the specified index.

        Args:
            idx: The index of the data sample to retrieve.

        Returns:
            A pandas Series containing the sample indice.
        """
        if self.group_by_pdb_id:
            # indices_list is a list of per-PDB DataFrames in this case.
            sample_indice = self.indices_list[idx].iloc[0]
        else:
            sample_indice = self.indices_list.iloc[idx]
        return sample_indice

    def _get_eval_chain_interface_mask(
        self, idx: int, atom_array_chain_id: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
        """
        Retrieves the evaluation chain/interface mask for a given index.

        Args:
            idx: The index of the data sample.
            atom_array_chain_id: An array containing the chain IDs of the atom array.

        Returns:
            A tuple containing the evaluation type, cluster ID, chain 1 mask, and chain 2 mask.
        """
        if self.group_by_pdb_id:
            df = self.indices_list[idx]
        else:
            df = self.indices_list.iloc[idx : idx + 1]

        # Keep only rows whose eval_type is an evaluation chain/interface type.
        df = df[df["eval_type"].apply(lambda x: x in EvaluationChainInterface)].copy()
        if len(df) < 1:
            raise ValueError(
                "Cannot find a chain/interface for evaluation in the PDB."
            )

        def get_atom_mask(row):
            chain_1_mask = atom_array_chain_id == row["chain_1_id"]
            if row["type"] == "chain":
                chain_2_mask = chain_1_mask
            else:
                chain_2_mask = atom_array_chain_id == row["chain_2_id"]
            chain_1_mask = torch.tensor(chain_1_mask).bool()
            chain_2_mask = torch.tensor(chain_2_mask).bool()
            if chain_1_mask.sum() == 0 or chain_2_mask.sum() == 0:
                return None, None
            return chain_1_mask, chain_2_mask

        df["chain_1_mask"], df["chain_2_mask"] = zip(*df.apply(get_atom_mask, axis=1))
        df = df[df["chain_1_mask"].notna()]

        if len(df) < 1:
            raise ValueError(
                "Cannot find a chain/interface for evaluation in the atom_array."
            )

        eval_type = np.array(df["eval_type"].tolist())
        cluster_id = np.array(df["cluster_id"].tolist())
        # Stack per-row masks into [N_eval, N_atom] tensors.
        chain_1_mask = torch.stack(df["chain_1_mask"].tolist())
        chain_2_mask = torch.stack(df["chain_2_mask"].tolist())

        return eval_type, cluster_id, chain_1_mask, chain_2_mask

    def get_feature_and_label(
        self,
        idx: int,
        token_array: TokenArray,
        atom_array: AtomArray,
        msa_features: dict[str, Any],
        template_features: dict[str, Any],
        full_atom_array: AtomArray,
        is_spatial_crop: bool = True,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Get feature and label information for a given data point.
        It uses a Featurizer object to obtain input features and labels, then adds
        further features and labels on top. Finally, it returns the feature dictionary,
        the label dictionary, and the full-complex label dictionary.

        Args:
            idx: Index of the data point.
            token_array: Cropped token array.
            atom_array: Cropped atom array containing atomic information.
            msa_features: Dictionary of MSA features.
            template_features: Dictionary of template features.
            full_atom_array: Full atom array containing all atoms.
            is_spatial_crop: Flag indicating whether spatial cropping is applied, by default True.

        Returns:
            A tuple containing the feature dictionary, the label dictionary, and the
            full-complex label dictionary.

        Raises:
            ValueError: If the ligand cannot be found in the data point.
        """
        feat = Featurizer(
            cropped_token_array=token_array,
            cropped_atom_array=atom_array,
            ref_pos_augment=self.ref_pos_augment,
            lig_atom_rename=self.lig_atom_rename,
        )
        features_dict = feat.get_all_input_features()
        labels_dict = feat.get_labels()

        # Atom permutation lists, used downstream for symmetry-related atom matching.
        features_dict["atom_perm_list"] = feat.get_atom_permutation_list()

        # Ground-truth labels over the full (uncropped) complex.
        label_full_dict, full_atom_array = Featurizer.get_gt_full_complex_features(
            atom_array=full_atom_array,
            cropped_atom_array=atom_array,
            get_cropped_asym_only=is_spatial_crop,
        )

        if self.find_pocket:
            # Identify the ligand of interest and its binding pocket.
            sample_indice = self._get_sample_indice(idx=idx)
            if sample_indice.mol_1_type == "ligand":
                lig_entity_id = str(sample_indice.entity_1_id)
                lig_chain_id = str(sample_indice.chain_1_id)
            elif sample_indice.mol_2_type == "ligand":
                lig_entity_id = str(sample_indice.entity_2_id)
                lig_chain_id = str(sample_indice.chain_2_id)
            else:
                raise ValueError("Cannot find ligand from this data point.")

            assert lig_entity_id in set(atom_array.label_entity_id)
            assert lig_chain_id in set(atom_array.chain_id)

            lig_asym_id = atom_array.label_asym_id[atom_array.chain_id == lig_chain_id]
            assert len(np.unique(lig_asym_id)) == 1
            lig_asym_id = lig_asym_id[0]
            ligands = [lig_asym_id]

            if self.find_all_pockets:
                # Also consider every other copy of the same ligand entity.
                all_lig_asym_ids = set(
                    full_atom_array[
                        full_atom_array.label_entity_id == lig_entity_id
                    ].label_asym_id
                )
                ligands.extend(list(all_lig_asym_ids - set([lig_asym_id])))

            interested_ligand_mask, pocket_mask = feat.get_lig_pocket_mask(
                atom_array=full_atom_array, lig_label_asym_id=ligands
            )

            label_full_dict["pocket_mask"] = pocket_mask
            label_full_dict["interested_ligand_mask"] = interested_ligand_mask

        if self.find_eval_chain_interface:
            eval_type, cluster_id, chain_1_mask, chain_2_mask = (
                self._get_eval_chain_interface_mask(
                    idx=idx, atom_array_chain_id=full_atom_array.chain_id
                )
            )
            labels_dict["eval_type"] = eval_type
            labels_dict["cluster_id"] = cluster_id
            labels_dict["chain_1_mask"] = chain_1_mask
            labels_dict["chain_2_mask"] = chain_2_mask

        # Fill in dummy MSA/template features if real ones are unavailable.
        dummy_feats = []
        if len(msa_features) == 0:
            dummy_feats.append("msa")
        else:
            msa_features = dict_to_tensor(msa_features)
            features_dict.update(msa_features)
        if len(template_features) == 0:
            dummy_feats.append("template")
        else:
            template_features = dict_to_tensor(template_features)
            features_dict.update(template_features)

        features_dict = make_dummy_feature(
            features_dict=features_dict, dummy_feats=dummy_feats
        )

        features_dict = data_type_transform(feat_or_label_dict=features_dict)
        labels_dict = data_type_transform(feat_or_label_dict=labels_dict)

        features_dict["is_distillation"] = torch.tensor([self.is_distillation])
        if self.is_distillation is True:
            features_dict["resolution"] = torch.tensor([-1.0])
        return features_dict, labels_dict, label_full_dict
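
# A minimal usage sketch for BaseSingleDataset (paths and config values below are
# assumptions, shown only to illustrate the expected call pattern):
#
#   dataset = BaseSingleDataset(
#       mmcif_dir="/path/to/mmcif",
#       bioassembly_dict_dir="/path/to/bioassembly_dicts",
#       indices_fpath="/path/to/indices.csv",
#       cropping_configs={"crop_size": 384, "method_weights": [0.2, 0.4, 0.4]},
#       msa_featurizer=None,
#       template_featurizer=None,
#       name="example_set",
#   )
#   sample = dataset[0]  # dict with "input_feature_dict", "label_dict", "label_full_dict", "basic"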
def get_msa_featurizer(configs, dataset_name: str, stage: str) -> Optional[Callable]:
    """
    Creates and returns an MSAFeaturizer object based on the provided configurations.

    Args:
        configs: A dictionary containing the configurations for the MSAFeaturizer.
        dataset_name: The name of the dataset.
        stage: The stage of the dataset (e.g., 'train', 'test').

    Returns:
        An MSAFeaturizer object if MSA is enabled in the configurations, otherwise None.
    """
    if "msa" in configs["data"] and configs["data"]["msa"]["enable"]:
        msa_info = configs["data"]["msa"]
        msa_args = deepcopy(msa_info)

        # Dataset-specific MSA settings override the global ones.
        if "msa" in (dataset_config := configs["data"][dataset_name]):
            for k, v in dataset_config["msa"].items():
                if k not in ["prot", "rna"]:
                    msa_args[k] = v
                else:
                    for kk, vv in dataset_config["msa"][k].items():
                        msa_args[k][kk] = vv

        prot_msa_args = msa_args["prot"]
        prot_msa_args.update(
            {
                "dataset_name": dataset_name,
                "merge_method": msa_args["merge_method"],
                "max_size": msa_args["max_size"][stage],
            }
        )

        rna_msa_args = msa_args["rna"]
        rna_msa_args.update(
            {
                "dataset_name": dataset_name,
                "merge_method": msa_args["merge_method"],
                "max_size": msa_args["max_size"][stage],
            }
        )

        return MSAFeaturizer(
            prot_msa_args=prot_msa_args,
            rna_msa_args=rna_msa_args,
            enable_rna_msa=configs.data.msa.enable_rna_msa,
        )
    else:
        return None
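
# The MSA-related part of `configs` is expected to look roughly like the sketch below
# (keys inferred from the accesses above; concrete values are assumptions):
#
#   configs["data"]["msa"] = {
#       "enable": True,
#       "enable_rna_msa": False,
#       "merge_method": "...",                         # assumed placeholder
#       "max_size": {"train": ..., "test": ...},       # per-stage MSA size limit
#       "prot": {...},                                 # protein MSA args
#       "rna": {...},                                  # RNA MSA args
#   }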
class WeightedMultiDataset(Dataset):
    """
    A weighted dataset composed of multiple datasets with weights.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        dataset_names: list[str],
        datapoint_weights: list[list[float]],
        dataset_sample_weights: list[float],
    ):
        """
        Initializes the WeightedMultiDataset.

        Args:
            datasets: A list of Dataset objects.
            dataset_names: A list of dataset names corresponding to the datasets.
            datapoint_weights: A list of lists containing sampling weights for each datapoint in the datasets.
            dataset_sample_weights: A list of sampling weights, one per dataset.
        """
        self.datasets = datasets
        self.dataset_names = dataset_names
        self.datapoint_weights = datapoint_weights
        self.dataset_sample_weights = torch.Tensor(dataset_sample_weights)
        self.iteration = 0
        self.offset = 0
        self.init_datasets()

    def init_datasets(self):
        """Calculate global weights of each datapoint in datasets for future sampling."""
        self.merged_datapoint_weights = []
        self.weight = 0.0
        self.dataset_indices = []
        self.within_dataset_indices = []
        for dataset_index, (
            dataset,
            datapoint_weight_list,
            dataset_weight,
        ) in enumerate(
            zip(self.datasets, self.datapoint_weights, self.dataset_sample_weights)
        ):
            # Normalize per-datapoint weights within each dataset, then scale by the
            # dataset-level weight so datasets are sampled in the configured proportions.
            weight_sum = sum(datapoint_weight_list)
            datapoint_weight_list = [
                dataset_weight * w / weight_sum for w in datapoint_weight_list
            ]
            self.merged_datapoint_weights.extend(datapoint_weight_list)
            self.weight += dataset_weight
            self.dataset_indices.extend([dataset_index] * len(datapoint_weight_list))
            self.within_dataset_indices.extend(list(range(len(datapoint_weight_list))))
        self.merged_datapoint_weights = torch.tensor(
            self.merged_datapoint_weights, dtype=torch.float64
        )

    def __len__(self) -> int:
        return len(self.merged_datapoint_weights)

    def __getitem__(self, index: int) -> dict[str, dict]:
        return self.datasets[self.dataset_indices[index]][
            self.within_dataset_indices[index]
        ]
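
# Note: `merged_datapoint_weights` is exposed so that a weighted sampler (for example
# torch.utils.data.WeightedRandomSampler) can draw indices in proportion to these
# weights; the sampler itself is constructed elsewhere (assumption based on how the
# weights are exposed here).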
def get_weighted_pdb_weight(
    data_type: str,
    cluster_size: int,
    chain_count: dict,
    eps: float = 1e-9,
    beta_dict: Optional[dict] = None,
    alpha_dict: Optional[dict] = None,
) -> float:
    """
    Get the sampling weight for one example in a weighted PDB dataset.

    Args:
        data_type (str): Type of data, either 'chain' or 'interface'.
        cluster_size (int): Cluster size of this chain/interface.
        chain_count (dict): Count of each kind of chain, e.g., {"prot": int, "nuc": int, "ligand": int}.
        eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-9.
        beta_dict (Optional[dict], optional): Dictionary containing beta values for 'chain' and 'interface'.
        alpha_dict (Optional[dict], optional): Dictionary containing alpha values for different chain types.

    Returns:
        float: Calculated weight for the given chain/interface.
    """
    if not beta_dict:
        beta_dict = {
            "chain": 0.5,
            "interface": 1,
        }
    if not alpha_dict:
        alpha_dict = {
            "prot": 3,
            "nuc": 3,
            "ligand": 1,
        }

    assert cluster_size > 0
    assert data_type in ["chain", "interface"]
    beta = beta_dict[data_type]
    assert set(chain_count.keys()).issubset(set(alpha_dict.keys()))
    # weight = beta * sum_m(alpha_m * n_m) / cluster_size
    weight = (
        beta
        * sum(
            [alpha * chain_count[data_mode] for data_mode, alpha in alpha_dict.items()]
        )
        / (cluster_size + eps)
    )
    return weight
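
# Worked example (toy numbers): a protein-ligand interface whose cluster has 4 members,
# using the default beta/alpha values above, gets
#   weight = 1 * (3 * 1 + 3 * 0 + 1 * 1) / (4 + 1e-9) ≈ 1.0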
def calc_weights_for_df(
    indices_df: pd.DataFrame, beta_dict: dict[str, Any], alpha_dict: dict[str, Any]
) -> pd.DataFrame:
    """
    Calculate weights for each example in the dataframe.

    Args:
        indices_df: A pandas DataFrame containing the indices.
        beta_dict: A dictionary containing beta values for different data types.
        alpha_dict: A dictionary containing alpha values for different data types.

    Returns:
        A pandas DataFrame with a column 'weights' containing the calculated weights.
    """
    # Identify examples that share the same (pdb, assembly, sorted entity pair) key.
    indices_df["pdb_sorted_entity_id"] = indices_df.apply(
        lambda x: f"{x['pdb_id']}_{x['assembly_id']}_{'_'.join(sorted([str(x['entity_1_id']), str(x['entity_2_id'])]))}",
        axis=1,
    )

    entity_member_num_dict = {}
    for pdb_sorted_entity_id, sub_df in indices_df.groupby("pdb_sorted_entity_id"):
        entity_member_num_dict[pdb_sorted_entity_id] = len(sub_df)
    indices_df["pdb_sorted_entity_id_member_num"] = indices_df.apply(
        lambda x: entity_member_num_dict[x["pdb_sorted_entity_id"]], axis=1
    )

    # Cluster sizes are counted over distinct (pdb, assembly, sorted entity pair) keys.
    cluster_size_record = {}
    for cluster_id, sub_df in indices_df.groupby("cluster_id"):
        cluster_size_record[cluster_id] = len(set(sub_df["pdb_sorted_entity_id"]))

    weights = []
    for _, row in indices_df.iterrows():
        data_type = row["type"]
        cluster_size = cluster_size_record[row["cluster_id"]]
        chain_count = {"prot": 0, "nuc": 0, "ligand": 0}
        for mol_type in [row["mol_1_type"], row["mol_2_type"]]:
            if chain_count.get(mol_type) is None:
                continue
            chain_count[mol_type] += 1

        weight = get_weighted_pdb_weight(
            data_type=data_type,
            cluster_size=cluster_size,
            chain_count=chain_count,
            beta_dict=beta_dict,
            alpha_dict=alpha_dict,
        )
        weights.append(weight)
    # Split each weight evenly across duplicated (pdb, assembly, entity pair) members.
    indices_df["weights"] = weights / indices_df["pdb_sorted_entity_id_member_num"]
    return indices_df

def get_sample_weights(
    sampler_type: str,
    indices_df: Optional[pd.DataFrame] = None,
    beta_dict: dict = {
        "chain": 0.5,
        "interface": 1,
    },
    alpha_dict: dict = {
        "prot": 3,
        "nuc": 3,
        "ligand": 1,
    },
    force_recompute_weight: bool = False,
) -> Union[pd.Series, list[float]]:
    """
    Computes sample weights based on the specified sampler type.

    Args:
        sampler_type: The type of sampler to use ('weighted' or 'uniform').
        indices_df: A pandas DataFrame containing the indices.
        beta_dict: A dictionary containing beta values for different data types.
        alpha_dict: A dictionary containing alpha values for different data types.
        force_recompute_weight: Whether to force recomputation of weights even if they already exist.

    Returns:
        A list (or Series) of sample weights.

    Raises:
        ValueError: If an unknown sampler type is provided.
    """
    if sampler_type == "weighted":
        assert indices_df is not None
        if "weights" not in indices_df.columns or force_recompute_weight:
            indices_df = calc_weights_for_df(
                indices_df=indices_df,
                beta_dict=beta_dict,
                alpha_dict=alpha_dict,
            )
        return indices_df["weights"].astype("float32")
    elif sampler_type == "uniform":
        assert indices_df is not None
        return [1 / len(indices_df) for _ in range(len(indices_df))]
    else:
        raise ValueError(f"Unknown sampler type: {sampler_type}")

def get_datasets(
    configs: ConfigDict, error_dir: Optional[str]
) -> tuple[WeightedMultiDataset, dict[str, BaseSingleDataset]]:
    """
    Get training and testing datasets given configs.

    Args:
        configs: A ConfigDict containing the dataset configurations.
        error_dir: The directory where error logs will be saved.

    Returns:
        A tuple containing the training dataset and a dictionary of testing datasets.
    """

    def _get_dataset_param(config_dict, dataset_name: str, stage: str):
        # Assemble the keyword arguments shared by train/test BaseSingleDataset instances.
        return {
            "name": dataset_name,
            **config_dict["base_info"],
            "cropping_configs": config_dict["cropping_configs"],
            "error_dir": error_dir,
            "msa_featurizer": get_msa_featurizer(configs, dataset_name, stage),
            "template_featurizer": None,
            "lig_atom_rename": config_dict.get("lig_atom_rename", False),
            "shuffle_mols": config_dict.get("shuffle_mols", False),
            "shuffle_sym_ids": config_dict.get("shuffle_sym_ids", False),
        }

    data_config = configs.data
    logger.info(f"Using train sets {data_config.train_sets}")
    assert len(data_config.train_sets) == len(
        data_config.train_sampler.train_sample_weights
    )
    train_datasets = []
    datapoint_weights = []
    for train_name in data_config.train_sets:
        config_dict = data_config[train_name].to_dict()
        dataset_param = _get_dataset_param(
            config_dict, dataset_name=train_name, stage="train"
        )
        dataset_param["ref_pos_augment"] = data_config.get(
            "train_ref_pos_augment", True
        )
        dataset_param["limits"] = data_config.get("limits", -1)
        train_dataset = BaseSingleDataset(**dataset_param)
        train_datasets.append(train_dataset)
        datapoint_weights.append(
            get_sample_weights(
                **data_config[train_name]["sampler_configs"],
                indices_df=train_dataset.indices_list,
            )
        )
    train_dataset = WeightedMultiDataset(
        datasets=train_datasets,
        dataset_names=data_config.train_sets,
        datapoint_weights=datapoint_weights,
        dataset_sample_weights=data_config.train_sampler.train_sample_weights,
    )

    test_datasets = {}
    test_sets = data_config.test_sets
    for test_name in test_sets:
        config_dict = data_config[test_name].to_dict()
        dataset_param = _get_dataset_param(
            config_dict, dataset_name=test_name, stage="test"
        )
        dataset_param["ref_pos_augment"] = data_config.get("test_ref_pos_augment", True)
        test_dataset = BaseSingleDataset(**dataset_param)
        test_datasets[test_name] = test_dataset
    return train_dataset, test_datasets
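
# Minimal call sketch (the config layout is inferred from the accesses above; actual
# project configs may differ):
#
#   train_ds, test_ds_dict = get_datasets(configs, error_dir="./errors")
#   # train_ds is a WeightedMultiDataset; test_ds_dict maps test-set names to
#   # BaseSingleDataset instances.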