import re
from functools import partial
from numbers import Number
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union, Literal

from lightning import LightningDataModule
import pandas as pd
import swifter
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader

from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler
from deepscreen.utils import get_logger

log = get_logger(__name__)

# Negated character classes: match any character *not* allowed in a SMILES or FASTA string, respectively
SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
FASTA_PAT = r"[^A-Z*\-]"


def validate_seq_str(seq, regex):
    if seq:
        err_charset = set(re.findall(regex, seq))
        if not err_charset:
            return None
        else:
            return ', '.join(err_charset)
    else:
        return 'Empty string'
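
# Illustrative behaviour of validate_seq_str (examples added for clarity, not part of the original module):
#   validate_seq_str('C1=CC=CC=C1', SMILES_PAT)  ->  None            (no invalid characters)
#   validate_seq_str('C1=CC?C1', SMILES_PAT)     ->  '?'             (invalid characters, comma-joined)
#   validate_seq_str('', SMILES_PAT)             ->  'Empty string'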


# TODO: save a list of corrupted records

def rdkit_canonicalize(smiles):
    from rdkit import Chem
    try:
        mol = Chem.MolFromSmiles(smiles)
        cano_smiles = Chem.MolToSmiles(mol)
        return cano_smiles
    except Exception as e:
        log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
        return smiles
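
# Illustrative: rdkit_canonicalize('OCC') returns RDKit's canonical SMILES for ethanol (typically 'CCO').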


class DTIDataset(Dataset):
    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            num_classes: Optional[int],
            data_path: str | Path,
            drug_featurizer: callable,
            protein_featurizer: callable,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            query: Optional[str] = 'X2'
    ):
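        # Expected columns (any other columns in the CSV are ignored):
        #   X1 = SMILES, ID1 = compound ID, X2 = FASTA sequence, ID2 = protein ID, Y = label, U = unit of Y.
        # X1 and X2 are required; the rest are optional.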
        df = pd.read_csv(
            data_path,
            engine='python',
            header=0,
            usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'],
            dtype={
                'X1': 'str',
                'ID1': 'str',
                'X2': 'str',
                'ID2': 'str',
                'Y': 'float32',
                'U': 'str',
            },
        )
        # Read the whole data table

        # if 'ID1' in df:
        #     self.x1_to_id1 = dict(zip(df['X1'], df['ID1']))
        # if 'ID2' in df:
        #     self.x2_to_id2 = dict(zip(df['X2'], df['ID2']))
        #     self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2']))))
        # self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2']))))

        # # train and eval mode data processing (fully labelled)
        # if 'Y' in df.columns and df['Y'].notnull().all():
        log.info(f"Processing data file: {data_path}")

        # Forward-fill all non-label columns
        df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)

        # TODO potentially allow running through the whole data validation process
        # error = False

        if 'Y' in df:
            log.info(f"Validating labels (`Y`)...")
            # TODO: check sklearn.utils.multiclass.check_classification_targets
            match task:
                case 'regression':
                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
                        f"""`Y` must be numeric for `regression` task,
                        but it has {set(df['Y'].swifter.apply(type))}."""

                case 'binary':
                    if all(df['Y'].isin([0, 1])):
                        assert not thresholds, \
                            f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
                            but still got `thresholds` ({thresholds}).
                            Double check your choices of `task` and `thresholds`, and records in the `Y` column."""
                    else:
                        assert thresholds, \
                            f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
                            but it has {pd.unique(df['Y'])}. 
                            You may set `thresholds` to discretize continuous labels."""  # TODO print err idx instead

                case 'multiclass':
                    assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'

                    if all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)):
                        assert not thresholds, \
                            f"""`Y` is already non-negative integers for 
                            `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
                            Double check your choice of `task`, `thresholds` and records in the `Y` column."""
                    else:
                        assert thresholds, \
                            f"""`Y` must be non-negative integers for
                            `multiclass` (classification) `task`, but it has {pd.unique(df['Y'])}.
                            You must set `thresholds` to discretize continuous labels."""  # TODO print err idx instead

            if 'U' in df.columns:
                units = df['U']
            else:
                units = None
                log.warning("Units ('U') not in the data table. "
                            "Assuming all labels to be discrete or in p-scale (-log10[M]).")

            # Transform labels
            df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds,
                                      discard_intermediate=discard_intermediate)

            # Filter out rows with a NaN in Y (missing values)
            df.dropna(subset=['Y'], inplace=True)

            match task:
                case 'regression':
                    df['Y'] = df['Y'].astype('float32')
                    assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
                        f"""`Y` must be numeric for `regression` task, 
                        but after transformation it still has {set(df['Y'].swifter.apply(type))}.
                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
                    # TODO print err idx instead
                case 'binary':
                    df['Y'] = df['Y'].astype('int')
                    assert all(df['Y'].isin([0, 1])), \
                        f"""`Y` must be 0 or 1 for `task=binary`,
                        but after transformation it still has {pd.unique(df['Y'])}.
                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
                    # TODO print err idx instead
                case 'multiclass':
                    df['Y'] = df['Y'].astype('int')
                    assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \
                        f"""`Y` must be non-negative integers for `task=multiclass`,
                        but after transformation it still has {pd.unique(df['Y'])}.
                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
                    # TODO print err idx instead
                    target_n_unique = df['Y'].nunique()
                    assert target_n_unique == num_classes, \
                        f"""You have set `num_classes` for `task=multiclass` to {num_classes},
                        but after transformation Y still has {target_n_unique} unique labels.
                        Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""

        log.info("Validating SMILES (`X1`)...")
        df['X1_ERR'] = df['X1'].swifter.progress_bar(
            desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT)
        if not df['X1_ERR'].isna().all():
            raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
        df['X1^'] = df['X1'].apply(rdkit_canonicalize)  # swifter

        log.info("Validating FASTA (`X2`)...")
        df['X2'] = df['X2'].str.upper()
        df['X2_ERR'] = df['X2'].swifter.progress_bar(
            desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT)
        if not df['X2_ERR'].isna().all():
            raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")

        # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate
        if query:
            df['ID^'] = LabelEncoder().fit_transform(df[query])

        self.df = df
        self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
        self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, i):
        sample = self.df.loc[i]
        sample_dict = {
            'N': i,
            'X1': sample['X1'],
            'X1^': self.drug_featurizer(sample['X1^']),
            # 'ID1': sample.get('ID1'),
            'X2': sample['X2'],
            'X2^': self.protein_featurizer(sample['X2']),
            # 'ID2': sample.get('ID2'),
            # 'Y': sample.get('Y'),
            # 'ID^': sample.get('ID^'),
        }
        optional_keys = ['ID1', 'ID2', 'ID^', 'Y']
        sample_dict.update({key: sample[key] for key in optional_keys if sample.get(key) is not None})
        return sample_dict
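
# Illustrative usage of DTIDataset (a sketch; the path, thresholds and featurizers are placeholders):
#     dataset = DTIDataset(task='binary', num_classes=None, data_path='data/dti.csv',
#                          drug_featurizer=None, protein_featurizer=None,
#                          thresholds=7)  # thresholds only if `Y` holds continuous values (e.g. p-scale affinities)
#     sample = dataset[0]  # dict with 'N', 'X1', 'X1^', 'X2', 'X2^' plus 'ID1'/'ID2'/'ID^'/'Y' when present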


class DTIDataModule(LightningDataModule):
    """
    DTI DataModule

    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            num_classes: Optional[int],
            batch_size: int,
            # train: bool,
            drug_featurizer: callable,
            protein_featurizer: callable,
            collator: callable = collate_fn,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Sequence[Number | str]] = None,
            split: Optional[callable] = None,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        super().__init__()

        self.train_data: Optional[Dataset] = None
        self.val_data: Optional[Dataset] = None
        self.test_data: Optional[Dataset] = None
        self.predict_data: Optional[Dataset] = None
        self.split = split
        self.collator = collator
        self.dataset = partial(
            DTIDataset,
            task=task,
            num_classes=num_classes,
            drug_featurizer=drug_featurizer,
            protein_featurizer=protein_featurizer,
            thresholds=thresholds,
            discard_intermediate=discard_intermediate
        )

        # this line allows access to init params with 'self.hparams' and ensures they are stored in the ckpt
        self.save_hyperparameters(logger=False)  # ignore=['split']

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
        """
        Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`, `self.predict_data`.
        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.
        """
        # load and split datasets only if not loaded in initialization
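        # Three configurations are supported (see the branches below):
        #   1. `data_file` + `split` + numeric `train_val_test_split`: split a single data file into subsets.
        #   2. `train_val_test_split` given as file paths (no `data_file`/`split`): load each subset from its own file.
        #   3. only `data_file` (no `train_val_test_split`/`split`): use it for both testing and prediction.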
        if not any([self.train_data, self.test_data, self.val_data, self.predict_data]):
            if self.hparams.train_val_test_split:
                if len(self.hparams.train_val_test_split) != 3:
                    raise ValueError('Length of `train_val_test_split` must be 3. '
                                     'Set the second element to None for training without validation. '
                                     'Set the third element to None for training without testing.')

                self.train_data = self.hparams.train_val_test_split[0]
                self.val_data = self.hparams.train_val_test_split[1]
                self.test_data = self.hparams.train_val_test_split[2]

                if all([self.hparams.data_file, self.split]):
                    if all(isinstance(split, Number) or split is None
                           for split in self.hparams.train_val_test_split):
                        split_data = self.split(
                            dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
                            lengths=[split for split in self.hparams.train_val_test_split if split is not None]
                        )
                        for dataset in ['train_data', 'val_data', 'test_data']:
                            if getattr(self, dataset) is not None:
                                setattr(self, dataset, split_data.pop(0))

                    else:
                        raise ValueError('`train_val_test_split` must be a sequence of numbers or None '
                                         '(float for percentages, int for sample counts) '
                                         'if both `data_file` and `split` have been specified.')

                elif (all(isinstance(split, str) or split is None
                          for split in self.hparams.train_val_test_split)
                      and not any([self.hparams.data_file, self.split])):
                    for dataset in ['train_data', 'val_data', 'test_data']:
                        if getattr(self, dataset) is not None:
                            data_path = Path(getattr(self, dataset))
                            if not data_path.is_absolute():
                                data_path = Path(self.hparams.data_dir, data_path)
                            setattr(self, dataset, self.dataset(data_path=data_path))

                else:
                    raise ValueError('For training, you must specify either all of `data_file`, `split`, '
                                     'and `train_val_test_split` as a sequence of numbers or '
                                     'solely `train_val_test_split` as a sequence of data file paths.')

            elif self.hparams.data_file and not any([self.split, self.hparams.train_val_test_split]):
                data_path = Path(self.hparams.data_file)
                if not data_path.is_absolute():
                    data_path = Path(self.hparams.data_dir, data_path)
                self.test_data = self.predict_data = self.dataset(data_path=data_path)

            else:
                raise ValueError("For training, you must specify `train_val_test_split`. "
                                 "For testing/predicting, you must specify only `data_file` without "
                                 "`train_val_test_split` or `split`.")

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.train_data,
                batch_size=self.hparams.batch_size,
                # Dropping the last batch prevents problems caused by variable batch sizes in training (e.g.,
                # a final batch of size 1 breaking BatchNorm), and shuffling ensures the model sees all samples across epochs.
                drop_last=True,
                shuffle=True,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.val_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.test_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def predict_dataloader(self):
        return DataLoader(
            dataset=self.predict_data,
            batch_sampler=SafeBatchSampler(
                data_source=self.predict_data,
                batch_size=self.hparams.batch_size,
                drop_last=False,
                shuffle=False
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
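

# Illustrative end-to-end usage (a sketch; the featurizer callables and file names are placeholders, not part of this module):
#     dm = DTIDataModule(task='binary', num_classes=None, batch_size=32,
#                        drug_featurizer=my_drug_featurizer, protein_featurizer=my_protein_featurizer,
#                        data_dir='data/', train_val_test_split=('train.csv', 'val.csv', 'test.csv'))
#     dm.setup()
#     train_loader = dm.train_dataloader()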