from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch_geometric.data import Batch as PyGBatch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy


def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning that calling `tokenizer.pad` directly is
    sub-optimal when using a fast tokenizer.
    """
    # Tokenizers without a deprecation-warning registry can be padded directly.
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Temporarily mark the warning as already emitted, then restore its previous state.
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded


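# Illustrative call (a sketch, not exercised by this module): assuming `tokenizer`
# is any `PreTrainedTokenizerBase` and `encoded` is a list of dicts containing
# "input_ids", the helper is used exactly like `tokenizer.pad` itself:
#
#     padded = pad_without_fast_tokenizer_warning(
#         tokenizer, encoded, padding=True, return_tensors="pt"
#     )

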
@dataclass
class DataCollatorForSeqGraph:
    """
    Data collator that dynamically pads the tokenized inputs and the labels, and
    batches the molecule graphs referenced by each sequence (looked up in
    `mol_id_to_pyg`) into PyTorch Geometric `Batch` objects.
    """

    tokenizer: PreTrainedTokenizerBase
    mol_id_to_pyg: Dict[str, Any]
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0] else None
        # Treat a batch whose labels are all None as having no labels at all.
        if labels is not None and all(label is None for label in labels):
            labels = None

        # Split the graph-related fields out of each feature so that only
        # token-level fields are handed to the tokenizer for padding.
        molecule_ids_list = []
        retro_labels_list = []
        retro_products_list = []
        non_labels_features = []
        for feature in features:
            new_feature = {k: v for k, v in feature.items() if k != label_name}
            molecule_ids_list.append(new_feature.pop("molecule_ids", None))
            retro_labels_list.append(new_feature.pop("retro_labels", None))
            retro_products_list.append(new_feature.pop("retro_product_ids", None))
            non_labels_features.append(new_feature)

        # Look up the PyG graph for every referenced molecule id. The graph at
        # position 0 of each sequence is additionally collected as that
        # sequence's design graph (position 0 is assumed to hold a valid id).
        molecule_graphs_list = []
        design_graphs_list = []
        for molecule_ids in molecule_ids_list:
            if molecule_ids is not None and len(molecule_ids) > 0:
                for pos, mol_id in enumerate(molecule_ids):
                    if pos == 0:
                        design_graphs_list.append(self.mol_id_to_pyg[mol_id])
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        molecule_graphs_list.append(self.mol_id_to_pyg[mol_id])

        retro_product_graphs_list = []
        for retro_product_ids in retro_products_list:
            if retro_product_ids is not None and len(retro_product_ids) > 0:
                for mol_id in retro_product_ids:
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        retro_product_graphs_list.append(self.mol_id_to_pyg[mol_id])

        if molecule_graphs_list:
            batched_graphs = PyGBatch.from_data_list(molecule_graphs_list)
        else:
            batched_graphs = None

        if design_graphs_list:
            batched_design_graphs = PyGBatch.from_data_list(design_graphs_list)
        else:
            batched_design_graphs = None

        if retro_product_graphs_list:
            batched_retro_products = PyGBatch.from_data_list(retro_product_graphs_list)
        else:
            batched_retro_products = None

        # Pad retro labels to the longest retro-label sequence in the batch.
        if retro_labels_list and any(retro_labels is not None for retro_labels in retro_labels_list):
            max_retro_length = max(len(retro_labels) for retro_labels in retro_labels_list if retro_labels is not None)
            padded_retro_labels = [
                retro_labels + [self.label_pad_token_id] * (max_retro_length - len(retro_labels))
                if retro_labels is not None
                else [self.label_pad_token_id] * max_retro_length
                for retro_labels in retro_labels_list
            ]
        else:
            padded_retro_labels = None

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

batch["molecule_graphs"] = batched_graphs |
|
batch["design_graphs"] = batched_design_graphs |
|
batch["retro_product_graphs"] = batched_retro_products |
|
batch["retro_labels"] = torch.tensor(padded_retro_labels, dtype=torch.int64) |
|
|
|
|
|
        # Pad the labels by hand; they were removed from the features before
        # tokenizer padding.
        if labels is not None:
            max_label_length = max(len(label) for label in labels)
            if self.pad_to_multiple_of is not None:
                # Round the target length up to the nearest multiple.
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            padded_labels = [
                label + [self.label_pad_token_id] * (max_label_length - len(label))
                if padding_side == "right"
                else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                for label in labels
            ]
            batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64)

        # If the model can derive decoder inputs from labels (as seq2seq models
        # such as T5 or BART can), add them to the batch as well.
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch
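

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint name, molecule id, SMILES
# strings, and dummy graph below are assumptions made for this example, not
# part of the module; any seq2seq tokenizer and any PyG `Data` objects work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch_geometric.data import Data
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # hypothetical checkpoint

    # Minimal two-node graph standing in for a real molecular graph.
    dummy_graph = Data(
        x=torch.zeros(2, 4),
        edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
    )
    mol_id_to_pyg = {"MOL_0": dummy_graph}

    collator = DataCollatorForSeqGraph(tokenizer=tokenizer, mol_id_to_pyg=mol_id_to_pyg)
    features = [
        {
            "input_ids": tokenizer("c1ccccc1")["input_ids"],
            "labels": tokenizer("CCO")["input_ids"],
            "molecule_ids": ["MOL_0"],
        }
    ]
    batch = collator(features)
    print({k: type(v).__name__ for k, v in batch.items()})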