|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from collections import defaultdict |
|
from pathlib import Path |
|
from typing import Any, Optional, Union |
|
|
|
import biotite.structure.io as strucio |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from biotite.structure import AtomArray |
|
|
|
from protenix.data.msa_featurizer import MSAFeaturizer |
|
from protenix.data.parser import DistillationMMCIFParser, MMCIFParser |
|
from protenix.data.tokenizer import AtomArrayTokenizer, TokenArray |
|
from protenix.utils.cropping import CropData |
|
from protenix.utils.file_io import load_gzip_pickle |
|
|
|
# Select the "file_system" strategy for sharing tensors between processes.
# This is a module-level side effect: it takes effect as soon as this module is imported.
# NOTE(review): presumably chosen to avoid running out of file descriptors when
# many data-loading workers exchange tensors — confirm with the training setup.
torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
|
class DataPipeline(object):
    """
    DataPipeline groups static methods that handle data processing tasks related to bioassembly structures.

    The methods parse mmCIF files into bioassembly dicts, load cached bioassembly dicts,
    resolve the reference chain(s) of a sample, build raw MSA / template features,
    crop a bioassembly and write atom arrays to CIF files.
    """

    @staticmethod
    def get_data_from_mmcif(
        mmcif: Union[str, Path],
        pdb_cluster_file: Union[str, Path, None] = None,
        dataset: str = "WeightedPDB",
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """
        Get raw data from mmcif with tokenizer and a list of chains and interfaces for sampling.

        Args:
            mmcif (Union[str, Path]): The raw mmcif file.
            pdb_cluster_file (Union[str, Path, None], optional): Cluster info txt file. Defaults to None.
            dataset (str, optional): The dataset type, either "WeightedPDB" or "Distillation". Defaults to "WeightedPDB".

        Returns:
            tuple[list[dict[str, Any]], dict[str, Any]]:
                sample_indices_list (list[dict[str, Any]]): The sample indices list (each one is a chain or an interface).
                bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array, and token_array.
                    When no sample index is found, the dict is returned as produced by the parser,
                    i.e. without "token_array", "msa_features" and "template_features".

        Raises:
            NotImplementedError: If `dataset` is neither "WeightedPDB" nor "Distillation".
        """
        # Each dataset type has its own parser and its own accessor for the structure dict.
        if dataset == "WeightedPDB":
            parser = MMCIFParser(mmcif_file=mmcif)
            bioassembly_dict = parser.get_bioassembly()
        elif dataset == "Distillation":
            parser = DistillationMMCIFParser(mmcif_file=mmcif)
            bioassembly_dict = parser.get_structure_dict()
        else:
            raise NotImplementedError(
                'Unsupported "dataset", please input either "WeightedPDB" or "Distillation".'
            )

        sample_indices_list = parser.make_indices(
            bioassembly_dict=bioassembly_dict, pdb_cluster_file=pdb_cluster_file
        )
        if len(sample_indices_list) == 0:
            # Nothing to sample from this entry: skip annotation and tokenization.
            return [], bioassembly_dict

        # Attach the entry-level resolution to every atom as a per-atom annotation.
        atom_array = bioassembly_dict["atom_array"]
        atom_array.set_annotation(
            "resolution", [parser.resolution] * len(atom_array)
        )

        tokenizer = AtomArrayTokenizer(atom_array)
        token_array = tokenizer.get_token_array()

        # MSA and template features are not computed here; the keys are set to None.
        bioassembly_dict["msa_features"] = None
        bioassembly_dict["template_features"] = None
        bioassembly_dict["token_array"] = token_array
        return sample_indices_list, bioassembly_dict

    @staticmethod
    def get_label_entity_id_to_asym_id_int(
        atom_array: AtomArray,
    ) -> dict[str, set[int]]:
        """
        Get a dictionary that associates each label_entity_id with the set of its asym_id_int values.

        Args:
            atom_array (AtomArray): AtomArray object carrying the per-atom annotations
                `label_entity_id` and `asym_id_int`.

        Returns:
            dict[str, set[int]]: label_entity_id to the set of asym_id_int seen for that entity.
                The returned object is a `defaultdict(set)`; wrap it in `dict(...)` to get a
                plain dict that does not create entries on lookup.
        """
        entity_to_asym_id = defaultdict(set)
        # Walk the two annotation arrays in lockstep instead of materializing one
        # Atom object per atom; the collected (entity, asym) pairs are the same.
        for entity_id, asym_id_int in zip(
            atom_array.label_entity_id, atom_array.asym_id_int
        ):
            entity_to_asym_id[entity_id].add(asym_id_int)
        return entity_to_asym_id

    @staticmethod
    def get_data_bioassembly(
        bioassembly_dict_fpath: Union[str, Path],
    ) -> dict[str, Any]:
        """
        Get the bioassembly dict.

        Args:
            bioassembly_dict_fpath (Union[str, Path]): The path to the bioassembly dictionary file.

        Returns:
            dict[str, Any]: The bioassembly dict with sequence, atom_array and token_array.

        Raises:
            AssertionError: If the bioassembly dictionary file does not exist.
        """
        # Kept as an assert so callers that expect AssertionError keep working;
        # note that the check is skipped when Python runs with -O.
        assert os.path.exists(
            bioassembly_dict_fpath
        ), f"File not exists {bioassembly_dict_fpath}"
        # NOTE(review): load_gzip_pickle presumably unpickles the file content;
        # only load files from trusted sources — confirm.
        bioassembly_dict = load_gzip_pickle(bioassembly_dict_fpath)

        return bioassembly_dict

    @staticmethod
    def _map_ref_chain(
        one_sample: pd.Series, bioassembly_dict: dict[str, Any]
    ) -> list[int]:
        """
        Map the chain or interface chain_x_id to the reference chain asym_id.

        Args:
            one_sample (pd.Series): A dict of one chain or interface from indices list.
                Read fields: "chain_1_id", "type", and "chain_2_id" unless type is "chain".
            bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array and token_array.

        Returns:
            list[int]: A list of asym_id_int of the chosen chain or interface, length 1 or 2.

        Raises:
            AssertionError: If a requested chain id is not present in the atom array.
        """
        atom_array = bioassembly_dict["atom_array"]
        # Fetch the annotation arrays once; they do not change inside the loop.
        chain_ids = atom_array.chain_id
        asym_id_ints = atom_array.asym_id_int

        ref_chain_indices = []
        for chain_id_field in ("chain_1_id", "chain_2_id"):
            chain_id = one_sample[chain_id_field]
            chain_mask = chain_ids == chain_id
            assert np.any(
                chain_mask
            ), f"PDB {bioassembly_dict['pdb_id']} {chain_id_field}:{chain_id} not in atom_array"
            # The first atom of the chain is enough to read the chain's asym_id_int.
            ref_chain_indices.append(asym_id_ints[chain_mask][0])
            if one_sample["type"] == "chain":
                # A single-chain sample has no second chain to resolve.
                break
        return ref_chain_indices

    @staticmethod
    def get_msa_raw_features(
        bioassembly_dict: dict[str, Any],
        selected_indices: Optional[np.ndarray],
        msa_featurizer: Optional[MSAFeaturizer],
    ) -> Optional[dict[str, np.ndarray]]:
        """
        Get tokenized MSA features of the bioassembly.

        Args:
            bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array and token_array.
            selected_indices (Optional[np.ndarray]): Cropped token indices. `crop` passes None
                when no cropping is performed.
            msa_featurizer (Optional[MSAFeaturizer]): MSAFeaturizer instance, or None to skip MSA features.

        Returns:
            Optional[dict[str, np.ndarray]]: The tokenized MSA features of the bioassembly,
                or None if the MSA featurizer is not provided.
        """
        if msa_featurizer is None:
            return None

        # Hand the featurizer a plain dict so that lookups do not create entries.
        entity_to_asym_id_int = dict(
            DataPipeline.get_label_entity_id_to_asym_id_int(
                bioassembly_dict["atom_array"]
            )
        )

        msa_feats = msa_featurizer(
            bioassembly_dict=bioassembly_dict,
            selected_indices=selected_indices,
            entity_to_asym_id_int=entity_to_asym_id_int,
        )

        return msa_feats

    @staticmethod
    def get_template_raw_features(
        bioassembly_dict: dict[str, Any],
        selected_indices: Optional[np.ndarray],
        template_featurizer: Optional[Any],
    ) -> Optional[dict[str, np.ndarray]]:
        """
        Get tokenized template features of the bioassembly.

        Args:
            bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array and token_array.
            selected_indices (Optional[np.ndarray]): Cropped token indices. `crop` passes None
                when no cropping is performed.
            template_featurizer (Optional[Any]): Placeholder for the template featurizer. When given,
                it is called with the keyword arguments `bioassembly_dict`, `selected_indices`
                and `entity_to_asym_id_int`.

        Returns:
            Optional[dict[str, np.ndarray]]: The tokenized template features of the bioassembly,
                or None if the template featurizer is not provided.
        """
        if template_featurizer is None:
            return None

        # Hand the featurizer a plain dict so that lookups do not create entries.
        entity_to_asym_id_int = dict(
            DataPipeline.get_label_entity_id_to_asym_id_int(
                bioassembly_dict["atom_array"]
            )
        )

        template_feats = template_featurizer(
            bioassembly_dict=bioassembly_dict,
            selected_indices=selected_indices,
            entity_to_asym_id_int=entity_to_asym_id_int,
        )
        return template_feats

    @staticmethod
    def crop(
        one_sample: pd.Series,
        bioassembly_dict: dict[str, Any],
        crop_size: int,
        msa_featurizer: Optional[MSAFeaturizer],
        template_featurizer: Optional[Any],
        method_weights: Optional[list[float]] = None,
        contiguous_crop_complete_lig: bool = False,
        spatial_crop_complete_lig: bool = False,
        drop_last: bool = False,
        remove_metal: bool = False,
    ) -> tuple[str, TokenArray, AtomArray, dict[str, Any], dict[str, Any], Any]:
        """
        Crop data based on the crop size and reference chain indices.

        Args:
            one_sample (pd.Series): A dict of one chain or interface from indices list.
            bioassembly_dict (dict[str, Any]): A dict of bioassembly dict with sequence, atom_array and token_array.
            crop_size (int): the crop size. A value <= 0 disables cropping.
            msa_featurizer (Optional[MSAFeaturizer]): MSAFeaturizer instance, or None to skip MSA features.
            template_featurizer (Optional[Any]): Placeholder for the template featurizer.
            method_weights (Optional[list[float]]): The weights corresponding to these three cropping methods:
                ["ContiguousCropping", "SpatialCropping", "SpatialInterfaceCropping"].
                Defaults to [0.2, 0.4, 0.4] when None.
            contiguous_crop_complete_lig (bool): Whether to crop the complete ligand in ContiguousCropping method.
            spatial_crop_complete_lig (bool): Whether to crop the complete ligand in SpatialCropping method.
            drop_last (bool): Whether to drop the last fragment in ContiguousCropping.
            remove_metal (bool): Whether to remove metal atoms from the crop.

        Returns:
            tuple[str, TokenArray, AtomArray, dict[str, Any], dict[str, Any], Any]:
                crop_method (str): The crop method, or "no_crop" when crop_size <= 0.
                cropped_token_array (TokenArray): TokenArray after cropping.
                cropped_atom_array (AtomArray): AtomArray after cropping.
                cropped_msa_features (dict[str, Any]): The cropped msa features.
                cropped_template_features (dict[str, Any]): The cropped template features.
                reference_token_index: Second value returned by `CropData.get_crop_indices`,
                    or -1 when no cropping is performed.

        Raises:
            AssertionError: If a reference chain is missing from the atom array, or if
                ContiguousCropping leaves 4 or fewer resolved atoms.
        """
        if crop_size <= 0:
            # No cropping requested: featurize the full bioassembly and return it untouched.
            selected_indices = None

            msa_features = DataPipeline.get_msa_raw_features(
                bioassembly_dict=bioassembly_dict,
                selected_indices=selected_indices,
                msa_featurizer=msa_featurizer,
            )

            template_features = DataPipeline.get_template_raw_features(
                bioassembly_dict=bioassembly_dict,
                selected_indices=selected_indices,
                template_featurizer=template_featurizer,
            )
            return (
                "no_crop",
                bioassembly_dict["token_array"],
                bioassembly_dict["atom_array"],
                msa_features or {},
                template_features or {},
                -1,
            )

        # Build a fresh list per call so the default is never shared between calls.
        if method_weights is None:
            method_weights = [0.2, 0.4, 0.4]

        ref_chain_indices = DataPipeline._map_ref_chain(
            one_sample=one_sample, bioassembly_dict=bioassembly_dict
        )

        crop = CropData(
            crop_size=crop_size,
            ref_chain_indices=ref_chain_indices,
            token_array=bioassembly_dict["token_array"],
            atom_array=bioassembly_dict["atom_array"],
            method_weights=method_weights,
            contiguous_crop_complete_lig=contiguous_crop_complete_lig,
            spatial_crop_complete_lig=spatial_crop_complete_lig,
            drop_last=drop_last,
            remove_metal=remove_metal,
        )

        crop_method = crop.random_crop_method()

        selected_indices, reference_token_index = crop.get_crop_indices(
            crop_method=crop_method
        )

        # Features are computed with the selected token indices, then cropped
        # together with the token and atom arrays.
        msa_features = DataPipeline.get_msa_raw_features(
            bioassembly_dict=bioassembly_dict,
            selected_indices=selected_indices,
            msa_featurizer=msa_featurizer,
        )

        template_features = DataPipeline.get_template_raw_features(
            bioassembly_dict=bioassembly_dict,
            selected_indices=selected_indices,
            template_featurizer=template_featurizer,
        )

        (
            cropped_token_array,
            cropped_atom_array,
            cropped_msa_features,
            cropped_template_features,
        ) = crop.crop_by_indices(
            selected_token_indices=selected_indices,
            msa_features=msa_features,
            template_features=template_features,
        )

        if crop_method == "ContiguousCropping":
            resolved_atom_num = cropped_atom_array.is_resolved.sum()

            # Kept as an assert so callers that expect AssertionError keep working;
            # note that the check is skipped when Python runs with -O.
            assert (
                resolved_atom_num > 4
            ), f"{resolved_atom_num=} <= 4 after ContiguousCropping"

        return (
            crop_method,
            cropped_token_array,
            cropped_atom_array,
            cropped_msa_features,
            cropped_template_features,
            reference_token_index,
        )

    @staticmethod
    def save_atoms_to_cif(
        output_cif_file: str, atom_array: AtomArray, include_bonds: bool = False
    ) -> None:
        """
        Save atom array data to a CIF file.

        The data block is named after the output file's base name with a trailing
        ".cif" removed.

        Args:
            output_cif_file (str): The output path for saving atom array in cif
            atom_array (AtomArray): The atom array to be saved
            include_bonds (bool): Whether to include bond information in the CIF file. Default is False.
        """
        # Strip only the trailing extension; a ".cif" inside the name is kept.
        data_block = os.path.basename(output_cif_file).removesuffix(".cif")
        strucio.save_structure(
            file_path=output_cif_file,
            array=atom_array,
            data_block=data_block,
            include_bonds=include_bonds,
        )
|
|