import json
import logging
import time
import traceback
import warnings
from typing import Any, Mapping

import torch
from biotite.structure import AtomArray
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from protenix.data.data_pipeline import DataPipeline
from protenix.data.json_to_feature import SampleDictToFeatures
from protenix.data.msa_featurizer import InferenceMSAFeaturizer
from protenix.data.utils import data_type_transform, make_dummy_feature
from protenix.utils.distributed import DIST_WRAPPER
from protenix.utils.torch_utils import dict_to_tensor

logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", module="biotite")


def get_inference_dataloader(configs: Any) -> DataLoader:
    """
    Creates and returns a DataLoader for inference using the InferenceDataset.

    Args:
        configs: A configuration object providing input_json_path, dump_dir and use_msa.

    Returns:
        A DataLoader object configured for inference.
    """
    inference_dataset = InferenceDataset(
        input_json_path=configs.input_json_path,
        dump_dir=configs.dump_dir,
        use_msa=configs.use_msa,
    )
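    # Shard samples across distributed ranks; shuffling is off so every run processes
    # the inputs in their original order.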
    sampler = DistributedSampler(
        dataset=inference_dataset,
        num_replicas=DIST_WRAPPER.world_size,
        rank=DIST_WRAPPER.rank,
        shuffle=False,
    )
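    # batch_size=1 with an identity collate_fn hands each sample through unchanged;
    # num_workers=0 keeps featurization in the main process.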
    dataloader = DataLoader(
        dataset=inference_dataset,
        batch_size=1,
        sampler=sampler,
        collate_fn=lambda batch: batch,
        num_workers=0,
    )
    return dataloader
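# A minimal usage sketch (illustrative only): any object exposing input_json_path,
# dump_dir and use_msa can serve as `configs`; the values below are hypothetical
# placeholders, not paths shipped with the project.
#
#     from types import SimpleNamespace
#     configs = SimpleNamespace(
#         input_json_path="examples/example.json",
#         dump_dir="./output",
#         use_msa=True,
#     )
#     for batch in get_inference_dataloader(configs):
#         data, atom_array, error_message = batch[0]  # identity collate, batch_size=1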


class InferenceDataset(Dataset):
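    """Dataset that turns the sample records of an input JSON file into inference features."""
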
    def __init__(
        self,
        input_json_path: str,
        dump_dir: str,
        use_msa: bool = True,
    ) -> None:
        self.input_json_path = input_json_path
        self.dump_dir = dump_dir
        self.use_msa = use_msa
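        # The input JSON is expected to be a list of sample dicts, each carrying at least
        # a "name" and its "sequences" (see __getitem__ / process_one below).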
        with open(self.input_json_path, "r") as f:
            self.inputs = json.load(f)

    def process_one(
        self,
        single_sample_dict: Mapping[str, Any],
    ) -> tuple[dict[str, torch.Tensor], AtomArray, dict[str, float]]:
        """
        Processes a single sample from the input JSON to generate features and statistics.

        Args:
            single_sample_dict: A dictionary containing the sample data.

        Returns:
            A tuple containing:
                - A data dictionary with the input feature dict and per-sample size statistics.
                - The corresponding AtomArray.
                - A dictionary of time tracking statistics.
        """
        t0 = time.time()
        sample2feat = SampleDictToFeatures(
            single_sample_dict,
        )
        features_dict, atom_array, token_array = sample2feat.get_feature_dict()
        features_dict["distogram_rep_atom_mask"] = torch.Tensor(
            atom_array.distogram_rep_atom_mask
        ).long()
        entity_poly_type = sample2feat.entity_poly_type
        t1 = time.time()
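        # Step 2: optionally build MSA features, using the entity-id -> asym-id mapping
        # to associate each entity's sequences with its chains.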
        entity_to_asym_id = DataPipeline.get_label_entity_id_to_asym_id_int(atom_array)
        msa_features = (
            InferenceMSAFeaturizer.make_msa_feature(
                bioassembly=single_sample_dict["sequences"],
                entity_to_asym_id=entity_to_asym_id,
                token_array=token_array,
                atom_array=atom_array,
            )
            if self.use_msa
            else {}
        )
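        # Step 3: pad missing feature groups with dummies. Template features are always
        # dummied at inference; MSA features only if none were generated above.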
dummy_feats = ["template"] |
|
if len(msa_features) == 0: |
|
dummy_feats.append("msa") |
|
else: |
|
msa_features = dict_to_tensor(msa_features) |
|
features_dict.update(msa_features) |
|
features_dict = make_dummy_feature( |
|
features_dict=features_dict, |
|
dummy_feats=dummy_feats, |
|
) |
|
|
|
|
|
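        # Step 4: convert feature values to the data types expected downstream.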
        feat = data_type_transform(feat_or_label_dict=features_dict)
        t2 = time.time()
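        # Step 5: assemble the output payload along with per-sample size statistics
        # (token / atom / MSA counts and per-molecule-type breakdowns).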
        data = {}
        data["input_feature_dict"] = feat

        N_token = feat["token_index"].shape[0]
        N_atom = feat["atom_to_token_idx"].shape[0]
        N_msa = feat["msa"].shape[0]
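        # Count atoms and tokens per molecule type from the per-atom is_{mol_type} masks.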
        stats = {}
        for mol_type in ["ligand", "protein", "dna", "rna"]:
            mol_type_mask = feat[f"is_{mol_type}"].bool()
            stats[f"{mol_type}/atom"] = int(mol_type_mask.sum(dim=-1).item())
            stats[f"{mol_type}/token"] = len(
                torch.unique(feat["atom_to_token_idx"][mol_type_mask])
            )
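        # Number of chains = number of unique asym ids; counts are stored as 1-element tensors.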
        N_asym = len(torch.unique(data["input_feature_dict"]["asym_id"]))
        data.update(
            {
                "N_asym": torch.tensor([N_asym]),
                "N_token": torch.tensor([N_token]),
                "N_atom": torch.tensor([N_atom]),
                "N_msa": torch.tensor([N_msa]),
            }
        )
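        # Map "protein/atom"-style stat keys onto short names such as "N_prot_atom".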
        def formatted_key(key):
            type_, unit = key.split("/")
            if type_ == "protein":
                type_ = "prot"
            elif type_ == "ligand":
                type_ = "lig"
            return f"N_{type_}_{unit}"
        data.update(
            {
                formatted_key(k): torch.tensor([stats[k]])
                for k in [
                    "protein/atom",
                    "ligand/atom",
                    "dna/atom",
                    "rna/atom",
                    "protein/token",
                    "ligand/token",
                    "dna/token",
                    "rna/token",
                ]
            }
        )
        data.update({"entity_poly_type": entity_poly_type})
        t3 = time.time()
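        # Stage timings; "crop" here covers the sample-dict -> feature conversion above
        # (no cropping is applied at inference time).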
        time_tracker = {
            "crop": t1 - t0,
            "featurizer": t2 - t1,
            "added_feature": t3 - t2,
        }
        return data, atom_array, time_tracker

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], AtomArray, str]:
        # Look up the sample outside the try block: if the record itself is missing or has
        # no "name", there is no meaningful partial result to return.
        single_sample_dict = self.inputs[index]
        sample_name = single_sample_dict["name"]
        try:
            logger.info(f"Featurizing {sample_name}...")
            data, atom_array, _ = self.process_one(
                single_sample_dict=single_sample_dict
            )
            error_message = ""
        except Exception as e:
            # Featurization failures are reported via error_message instead of raising,
            # so one bad sample does not abort the whole inference run.
            data, atom_array = {}, None
            error_message = f"{e}:\n{traceback.format_exc()}"
        data["sample_name"] = sample_name
        data["sample_index"] = index
        return data, atom_array, error_message
|