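"""Data collator for joint sequence / molecular-graph inputs.

`DataCollatorForSeqGraph` dynamically pads tokenized text features and labels,
and batches the per-example molecule, design, and retro-product graphs
(looked up in `mol_id_to_pyg`) into `torch_geometric.data.Batch` objects.
"""
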
import torch
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

from torch_geometric.data import Batch as PyGBatch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning that calling `pad` directly is sub-optimal when using a fast tokenizer.
    """
    # Feature extractors have no `deprecation_warnings` attribute, so pad directly to avoid errors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded

@dataclass
class DataCollatorForSeqGraph:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    """
    tokenizer: PreTrainedTokenizerBase
    mol_id_to_pyg: Dict[str, Any]
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        if labels is not None and all(label is None for label in labels):
            labels = None

        # Store molecule_ids, retro_labels, and retro_product_ids separately and remove from non_labels_features
        molecule_ids_list = []
        retro_labels_list = []
        retro_products_list = []
        non_labels_features = []
        for feature in features:
            new_feature = {k: v for k, v in feature.items() if k != label_name}
            if 'molecule_ids' in new_feature:
                molecule_ids_list.append(new_feature['molecule_ids'])
                del new_feature['molecule_ids']
            else:
                molecule_ids_list.append(None)
            if 'retro_labels' in new_feature:
                retro_labels_list.append(new_feature['retro_labels'])
                del new_feature['retro_labels']
            else:
                retro_labels_list.append(None)
            if 'retro_product_ids' in new_feature:
                retro_products_list.append(new_feature['retro_product_ids'])
                del new_feature['retro_product_ids']
            else:
                retro_products_list.append(None)
            non_labels_features.append(new_feature)

        # Look up the PyG Data object for each molecule ID. The molecule at position 0 of each
        # sequence is also collected as that sequence's design graph (position 0 is assumed to
        # hold a valid, non-padding ID).
        molecule_graphs_list = []
        design_graphs_list = []
        for molecule_ids in molecule_ids_list:
            if molecule_ids is not None and len(molecule_ids) > 0:
                for pos, mol_id in enumerate(molecule_ids):
                    if pos == 0:
                        design_graphs_list.append(self.mol_id_to_pyg[mol_id])
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        molecule_graphs_list.append(self.mol_id_to_pyg[mol_id])

        # Convert retro_product_ids to PyG Data objects, skipping padding and unknown IDs
        retro_product_graphs_list = []
        for retro_product_ids in retro_products_list:
            if retro_product_ids is not None and len(retro_product_ids) > 0:
                for mol_id in retro_product_ids:
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        retro_product_graphs_list.append(self.mol_id_to_pyg[mol_id])

        # Batch the PyG Data objects
        if molecule_graphs_list:
            batched_graphs = PyGBatch.from_data_list(molecule_graphs_list)
        else:
            batched_graphs = None
        
        if design_graphs_list:
            batched_design_graphs = PyGBatch.from_data_list(design_graphs_list)
        else:
            batched_design_graphs = None

        if retro_product_graphs_list:
            batched_retro_products = PyGBatch.from_data_list(retro_product_graphs_list)
        else:
            batched_retro_products = None

        # Pad retro_labels
        if retro_labels_list and any(retro_labels is not None for retro_labels in retro_labels_list):
            max_retro_length = max(len(retro_labels) for retro_labels in retro_labels_list if retro_labels is not None)
            padded_retro_labels = [
                retro_labels + [self.label_pad_token_id] * (max_retro_length - len(retro_labels))
                if retro_labels is not None
                else [self.label_pad_token_id] * max_retro_length
                for retro_labels in retro_labels_list
            ]
        else:
            padded_retro_labels = None

        # Pad other features
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        batch["molecule_graphs"] = batched_graphs
        batch["design_graphs"] = batched_design_graphs
        batch["retro_product_graphs"] = batched_retro_products
        batch["retro_labels"] = torch.tensor(padded_retro_labels, dtype=torch.int64)

        # Pad labels
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            if self.pad_to_multiple_of is not None:
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            padded_labels = [
                label + [self.label_pad_token_id] * (max_label_length - len(label))
                if padding_side == "right"
                else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                for label in labels
            ]
            batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64)

        # Prepare decoder_input_ids
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch
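
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original pipeline). The
# checkpoint name, toy graphs, and token IDs below are placeholders; it assumes
# dataset features carry "input_ids", "labels", "molecule_ids", "retro_labels",
# and "retro_product_ids", with molecule IDs that key into `mol_id_to_pyg`.
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from torch_geometric.data import Data

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

    # Toy molecule-ID -> PyG graph lookup (two small graphs with random node features).
    mol_id_to_pyg = {
        "mol-0": Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]])),
        "mol-1": Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]])),
    }

    features = [
        {
            "input_ids": [101, 2023, 2003, 102],
            "labels": [101, 2023, 102],
            "molecule_ids": ["mol-0", "mol-1"],
            "retro_labels": [1, 0],
            "retro_product_ids": ["mol-1"],
        },
        {
            "input_ids": [101, 2178, 102],
            "labels": [101, 2178, 2742, 102],
            "molecule_ids": ["mol-1"],
            "retro_labels": [0],
            "retro_product_ids": ["mol-0"],
        },
    ]

    collator = DataCollatorForSeqGraph(tokenizer=tokenizer, mol_id_to_pyg=mol_id_to_pyg)
    batch = collator(features)
    print(batch["input_ids"].shape, batch["labels"].shape)  # both padded to the same batch length
    print(batch["molecule_graphs"])                         # Batch over 3 molecule graphs
    print(batch["design_graphs"])                           # Batch over 2 design graphs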