from collections import OrderedDict
import random

import numpy as np
import pandas as pd
import torch

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel
from esm import BatchConverter, pretrained


class RandomCropBatchConverter(BatchConverter):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    For sequences longer than max_len, randomly crop a max_len window.
    """

    def __init__(self, alphabet, max_len):
        super(RandomCropBatchConverter, self).__init__(alphabet)
        self.max_len = max_len

    def __call__(self, raw_batch):
        cropped_batch = [(label, self._crop(seq)) for label, seq in raw_batch]
        return super(RandomCropBatchConverter, self).__call__(cropped_batch)

    def _crop(self, seq):
        if len(seq) <= self.max_len:
            return seq
        # Sample the crop start uniformly over all valid window positions.
        start_idx = np.random.choice(len(seq) - self.max_len + 1)
        return seq[start_idx:start_idx + self.max_len]


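# Minimal usage sketch: cropping happens before tokenization, so token
# length is bounded by max_len plus any special tokens. The checkpoint
# loader below is illustrative; any fair-esm `pretrained` loader that
# returns (model, alphabet) works.
#
#     model, alphabet = pretrained.esm1_t34_670M_UR50S()
#     batch_converter = RandomCropBatchConverter(alphabet, max_len=1022)
#     labels, strs, toks = batch_converter([('seq0', 'MKTAYIAKQR' * 200)])

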
class CosupRandomCropBatchConverter(RandomCropBatchConverter):
    """Random-crop converter for CosupervisionDataset items of the form
    (datatype, (label, seq)). All items in a batch must come from the
    same underlying dataset; the shared datatype is returned alongside
    the converted batch.
    """

    def __call__(self, raw_batch):
        datatype = [tup[0] for tup in raw_batch]
        assert len(np.unique(datatype)) == 1, \
            'A batch must not mix supervised and unsupervised entries.'
        datatype = datatype[0]
        raw_batch = [tup[1] for tup in raw_batch]
        return datatype, super(CosupRandomCropBatchConverter,
                               self).__call__(raw_batch)


class CSVBatchedDataset(FastaBatchedDataset):
    """Batched dataset over a CSV file (or dataframe) with `log_fitness`
    and `seq` columns; the log_fitness values serve as labels.
    """

    @classmethod
    def from_file(cls, csv_file):
        df = pd.read_csv(csv_file)
        return cls(df.log_fitness.values, df.seq.values)

    @classmethod
    def from_dataframe(cls, df):
        return cls(df.log_fitness.values, df.seq.values)


class CosupervisionDataset(object):
    """Concatenation of an unsupervised (FASTA) dataset and a supervised
    (CSV) dataset, indexed as one. Items are (datatype, item) pairs,
    where datatype is 'unsup' or 'sup'.
    """

    def __init__(self, fasta_batched_dataset, csv_batched_dataset):
        self.datasets = OrderedDict()
        self.datasets['unsup'] = fasta_batched_dataset
        self.datasets['sup'] = csv_batched_dataset
        self.lens = OrderedDict()
        for k, d in self.datasets.items():
            self.lens[k] = len(d)

    def __len__(self):
        return sum(self.lens.values())

    def __getitem__(self, idx):
        cumlen = 0
        for k in self.datasets.keys():
            if idx < cumlen + self.lens[k]:
                return k, self.datasets[k][idx - cumlen]
            cumlen += self.lens[k]
        raise IndexError(idx)

    def _offset_indices(self, batches, offset):
        return [[idx + offset for idx in batch] for batch in batches]

    def _get_batch_indices(self, per_dataset_batches):
        # Shift each dataset's batches into the concatenated index space,
        # then shuffle so supervised and unsupervised batches interleave.
        batches = []
        cumlen = 0
        for k in self.datasets.keys():
            batches += self._offset_indices(per_dataset_batches[k], cumlen)
            cumlen += self.lens[k]
        random.shuffle(batches)
        return batches

    def get_split_batch_indices(self, toks_per_batch, extra_toks_per_seq=0,
                                val_split=0.2):
        # Split each dataset's batches into train/val separately so that
        # both splits contain batches from both datasets.
        train_per_dataset_batches = OrderedDict()
        val_per_dataset_batches = OrderedDict()
        for k, d in self.datasets.items():
            batches = d.get_batch_indices(toks_per_batch, extra_toks_per_seq)
            random.shuffle(batches)
            split = int(np.floor(val_split * len(batches)))
            train_batches, val_batches = batches[split:], batches[:split]
            train_per_dataset_batches[k] = train_batches
            val_per_dataset_batches[k] = val_batches
        return (self._get_batch_indices(train_per_dataset_batches),
                self._get_batch_indices(val_per_dataset_batches))


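# Sketch of how these pieces compose (file names are placeholders, and
# `alphabet` comes from a pretrained ESM loader as above):
#
#     unsup = FastaBatchedDataset.from_file('family.fasta')
#     sup = CSVBatchedDataset.from_file('assay.csv')
#     dataset = CosupervisionDataset(unsup, sup)
#     train_idx, val_idx = dataset.get_split_batch_indices(
#         toks_per_batch=1024, extra_toks_per_seq=1)
#     loader = torch.utils.data.DataLoader(
#         dataset,
#         collate_fn=CosupRandomCropBatchConverter(alphabet, max_len=1022),
#         batch_sampler=train_idx)

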
class MaskedFastaBatchedDataset(FastaBatchedDataset):
    """For each sequence, mask all of its mutated positions (relative to
    the wild type) in a single data entry.
    """

    def __init__(self, sequence_labels, sequence_strs, mask_positions=None):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)
        self.mask_positions = (list(mask_positions)
                               if mask_positions is not None else None)

    @classmethod
    def from_file(cls, fasta_file, wt):
        ds = super(MaskedFastaBatchedDataset, cls).from_file(fasta_file)
        sequence_labels, sequence_strs = ds.sequence_labels, ds.sequence_strs
        mask_positions = []
        for s in sequence_strs:
            # Offset positions by +1 to account for the BOS token that is
            # prepended during tokenization.
            positions = [pos + 1 for pos in range(len(wt)) if s[pos] != wt[pos]]
            mask_positions.append(positions)
        return cls(sequence_labels, sequence_strs, mask_positions)

    @classmethod
    def from_dataframe(cls, df, wt):
        sequence_labels, sequence_strs = df.log_fitness.values, df.seq.values
        mask_positions = []
        for s in sequence_strs:
            positions = [pos + 1 for pos in range(len(wt)) if s[pos] != wt[pos]]
            mask_positions.append(positions)
        return cls(sequence_labels, sequence_strs, mask_positions)

    def __getitem__(self, idx):
        return (self.sequence_labels[idx], self.sequence_strs[idx],
                self.mask_positions[idx])


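# Sketch (placeholders: a dataframe `df` with `log_fitness`/`seq` columns
# and a wild-type string `wt` of the same length as each variant):
#
#     ds = MaskedFastaBatchedDataset.from_dataframe(df, wt)
#     label, seq, positions = ds[0]
#     # `positions` lists every mutated site, already offset for BOS.

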
class PLLFastaBatchedDataset(MaskedFastaBatchedDataset):
    """Batched dataset specialized for computing pseudo log-likelihoods.

    For each sequence, mask each of the mutated positions as a separate
    data entry.
    """

    @classmethod
    def from_file(cls, fasta_file, wt):
        ds = super(PLLFastaBatchedDataset, cls).from_file(fasta_file, wt)
        pll_sequence_labels = []
        pll_sequence_strs = []
        pll_mask_positions = []
        for i in range(len(ds.sequence_strs)):
            s = ds.sequence_strs[i]
            l = ds.sequence_labels[i]
            m = ds.mask_positions[i]
            if s == wt:
                # The wild type has no mutated positions; mask position 1
                # as a placeholder so it still yields one entry.
                pll_sequence_labels.append(l)
                pll_sequence_strs.append(s)
                pll_mask_positions.append(1)
                continue
            # One entry per mutated position, each masking a single site.
            for pos in m:
                pll_sequence_labels.append(l)
                pll_sequence_strs.append(s)
                pll_mask_positions.append(pos)
        return cls(pll_sequence_labels, pll_sequence_strs, pll_mask_positions)


class MaskedBatchConverter(BatchConverter):
    """Batch converter to be used with MaskedFastaBatchedDataset.

    Note that torch.tensor() requires every entry in a batch to carry the
    same number of mask positions (PLL entries carry exactly one).
    """

    def __call__(self, raw_batch):
        _raw_batch = [(l, s) for l, s, p in raw_batch]
        mask_pos = [p for l, s, p in raw_batch]
        mask_pos = torch.tensor(mask_pos).long()
        labels, strs, tokens = super(MaskedBatchConverter, self).__call__(_raw_batch)
        return labels, strs, tokens, mask_pos


class PLLBatchConverter(MaskedBatchConverter):
    pass


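# Sketch of the PLL pipeline ('variants.fasta', `wt`, and `alphabet` are
# placeholders): one masked entry per mutated position per sequence.
#
#     ds = PLLFastaBatchedDataset.from_file('variants.fasta', wt)
#     converter = PLLBatchConverter(alphabet)
#     labels, strs, toks, mask_pos = converter([ds[0], ds[1]])
#     toks[torch.arange(toks.size(0)), mask_pos] = alphabet.mask_idx

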
def random_mask_tokens(inputs, alphabet, mlm_probability=0.15):
    """Prepare masked token inputs/labels for masked language modeling.

    Of the positions selected for masking (15% by default): 80% become
    the mask token, 10% a random amino acid, 10% keep the original token.
    """
    labels = inputs.clone()
    device = inputs.device

    # Select positions to mask, excluding padding.
    probability_matrix = torch.full(labels.shape, mlm_probability,
                                    device=device)
    special_tokens_mask = (inputs == alphabet.padding_idx)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Loss is only computed on masked positions; -100 is the default
    # ignore_index of torch.nn.CrossEntropyLoss.
    labels[~masked_indices] = -100

    # 80% of masked positions: replace the token with the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8,
                                       device=device)).bool() & masked_indices
    inputs[indices_replaced] = alphabet.mask_idx

    # Half of the remaining 20% (10% overall): replace with a random token
    # drawn from the standard (non-special) region of the vocabulary. The
    # final 10% keep their original token.
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5,
                      device=device)).bool()
                      & masked_indices & ~indices_replaced)
    random_AAs = torch.randint(len(alphabet.prepend_toks),
                               len(alphabet.standard_toks), labels.shape,
                               dtype=torch.long, device=device)
    inputs[indices_random] = random_AAs[indices_random]

    return inputs, labels, masked_indices


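# Sketch of a masked-LM training step using the helper above (`model` is
# assumed to be a fair-esm model whose forward returns a dict with
# 'logits'; `toks` is a token batch from one of the converters above):
#
#     masked, labels, _ = random_mask_tokens(toks.clone(), alphabet)
#     logits = model(masked)['logits']  # (B, L, vocab)
#     loss = torch.nn.functional.cross_entropy(
#         logits.transpose(1, 2), labels, ignore_index=-100)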