gzhong's picture
Upload folder using huggingface_hub
7718235 verified
from collections import OrderedDict
import random
import numpy as np
import torch
import pandas as pd
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel
from esm import BatchConverter, pretrained
class RandomCropBatchConverter(BatchConverter):
    """Batch converter that caps sequence length by random cropping.

    Turns a raw batch of ``(label, sequence string)`` pairs into whatever
    the parent ``BatchConverter`` produces, after replacing every sequence
    longer than ``max_len`` with a randomly positioned window of exactly
    ``max_len`` characters.
    """

    def __init__(self, alphabet, max_len):
        """Store the crop length and initialise the parent with ``alphabet``."""
        super().__init__(alphabet)
        self.max_len = max_len

    def __call__(self, raw_batch):
        """Crop each sequence in ``raw_batch``, then defer to the parent."""
        cropped = []
        for label, seq in raw_batch:
            cropped.append((label, self._crop(seq)))
        return super().__call__(cropped)

    def _crop(self, seq):
        """Return ``seq`` unchanged if it fits, else a random window of it."""
        overflow = len(seq) - self.max_len
        if overflow > 0:
            # There are overflow + 1 valid window starts; pick one uniformly.
            offset = np.random.choice(overflow + 1)
            return seq[offset:offset + self.max_len]
        return seq
class CosupRandomCropBatchConverter(RandomCropBatchConverter):
    """Random-crop converter for batches tagged with a dataset type.

    Every entry of the raw batch is a ``(datatype, item)`` pair. All
    entries of one batch must carry the same datatype; that single tag is
    returned alongside the converted batch.
    """

    def __call__(self, raw_batch):
        """Return ``(datatype, converted_batch)`` for a homogeneous batch."""
        tags = [entry[0] for entry in raw_batch]
        # A batch mixing dataset types (or an empty batch) is rejected.
        assert len(np.unique(tags)) == 1
        items = [entry[1] for entry in raw_batch]
        return tags[0], super().__call__(items)
class CSVBatchedDataset(FastaBatchedDataset):
    """Batched dataset built from tabular data instead of a FASTA file.

    The table must expose a ``log_fitness`` column (used as the labels)
    and a ``seq`` column (used as the sequence strings).
    """

    @classmethod
    def from_file(cls, csv_file):
        """Read ``csv_file`` with pandas and build the dataset from it."""
        table = pd.read_csv(csv_file)
        labels = table.log_fitness.values
        seqs = table.seq.values
        return cls(labels, seqs)

    @classmethod
    def from_dataframe(cls, df):
        """Build the dataset from an already loaded DataFrame ``df``."""
        labels = df.log_fitness.values
        seqs = df.seq.values
        return cls(labels, seqs)
class CosupervisionDataset(object):
    """Concatenation of an unsupervised and a supervised batched dataset.

    The two datasets are stored under the keys ``'unsup'`` and ``'sup'``
    (in that order) and addressed through one global index space: indices
    ``0 .. len(unsup) - 1`` map to the unsupervised dataset and the
    following ``len(sup)`` indices map to the supervised one.

    Dataset lengths are captured once at construction time.
    """

    def __init__(self, fasta_batched_dataset, csv_batched_dataset):
        """Register both datasets and record their lengths.

        Args:
            fasta_batched_dataset: dataset stored under ``'unsup'``.
            csv_batched_dataset: dataset stored under ``'sup'``.
        """
        self.datasets = OrderedDict()
        self.datasets['unsup'] = fasta_batched_dataset
        self.datasets['sup'] = csv_batched_dataset
        self.lens = OrderedDict(
            (key, len(dataset)) for key, dataset in self.datasets.items())

    def __len__(self):
        """Return the total number of items across both datasets."""
        return sum(self.lens.values())

    def __getitem__(self, idx):
        """Return ``(dataset_key, item)`` for the global index ``idx``.

        Negative indices count from the end of the concatenation, as for
        built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        global_idx = idx + total if idx < 0 else idx
        if global_idx < 0 or global_idx >= total:
            # Raising (instead of silently returning None) is also what
            # lets plain ``for item in dataset`` iteration terminate.
            raise IndexError(
                'index {} is out of range for a dataset of size {}'.format(
                    idx, total))
        cumlen = 0
        for key, length in self.lens.items():
            if global_idx < cumlen + length:
                return key, self.datasets[key][global_idx - cumlen]
            cumlen += length
        # Unreachable while self.lens matches the datasets; kept as a guard.
        raise IndexError(
            'index {} is out of range for a dataset of size {}'.format(
                idx, total))

    def _offset_indices(self, batches, offset):
        """Return a copy of ``batches`` with ``offset`` added to every index."""
        return [[idx + offset for idx in batch] for batch in batches]

    def _get_batch_indices(self, per_dataset_batches):
        """Merge per-dataset batches into shuffled global-index batches.

        Args:
            per_dataset_batches: mapping from dataset key to a list of
                batches, each batch being a list of dataset-local indices.

        Returns:
            A randomly ordered list of batches expressed in global
            indices. Each batch still contains items of one dataset only.
        """
        batches = []
        cumlen = 0
        for key in self.datasets:
            batches += self._offset_indices(per_dataset_batches[key], cumlen)
            cumlen += self.lens[key]
        random.shuffle(batches)
        return batches

    def get_split_batch_indices(self, toks_per_batch, extra_toks_per_seq=0,
                                val_split=0.2):
        """Build train and validation batches over both datasets.

        Each dataset is batched independently through its own
        ``get_batch_indices``; its batches are shuffled and the first
        ``floor(val_split * n_batches)`` of them are held out for
        validation. The split is therefore made per batch, per dataset.

        Args:
            toks_per_batch: forwarded to each dataset's
                ``get_batch_indices``.
            extra_toks_per_seq: forwarded to each dataset's
                ``get_batch_indices``.
            val_split: fraction of each dataset's batches held out.

        Returns:
            ``(train_batches, val_batches)``, both in global indices.
        """
        train_per_dataset_batches = OrderedDict()
        val_per_dataset_batches = OrderedDict()
        for key, dataset in self.datasets.items():
            batches = dataset.get_batch_indices(
                toks_per_batch, extra_toks_per_seq)
            random.shuffle(batches)
            split = int(np.floor(val_split * len(batches)))
            train_per_dataset_batches[key] = batches[split:]
            val_per_dataset_batches[key] = batches[:split]
        train_batches = self._get_batch_indices(train_per_dataset_batches)
        val_batches = self._get_batch_indices(val_per_dataset_batches)
        return train_batches, val_batches
class MaskedFastaBatchedDataset(FastaBatchedDataset):
    """Batched dataset that also carries mask positions per sequence.

    For each sequence, all positions that differ from the wild type are
    stored together in one data entry. Positions are shifted by one to
    account for a start token in front of the sequence.
    """

    def __init__(self, sequence_labels, sequence_strs, mask_positions=None):
        """Store labels, sequences and (optionally) mask positions.

        Args:
            sequence_labels: iterable of labels, one per sequence.
            sequence_strs: iterable of sequence strings.
            mask_positions: optional iterable with one entry per sequence.
                It may be omitted (the inherited ``from_file`` constructs
                instances without it); item access then raises.
        """
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)
        # Always define the attribute so its absence can be detected and
        # reported explicitly instead of surfacing as a missing attribute.
        self.mask_positions = (
            list(mask_positions) if mask_positions is not None else None)

    @staticmethod
    def _mutated_positions(seq, wt):
        """Return the 1-based positions where ``seq`` differs from ``wt``.

        Only the first ``len(wt)`` positions are compared; a ``seq``
        shorter than ``wt`` raises ``IndexError``.
        """
        # +1 for start token
        return [pos + 1 for pos in range(len(wt)) if seq[pos] != wt[pos]]

    @classmethod
    def from_file(cls, fasta_file, wt):
        """Load ``fasta_file`` and mask every position mutated w.r.t. ``wt``."""
        ds = super().from_file(fasta_file)
        sequence_labels, sequence_strs = ds.sequence_labels, ds.sequence_strs
        mask_positions = [
            cls._mutated_positions(seq, wt) for seq in sequence_strs]
        return cls(sequence_labels, sequence_strs, mask_positions)

    @classmethod
    def from_dataframe(cls, df, wt):
        """Build the dataset from ``df`` (columns ``log_fitness`` and ``seq``).

        Mask positions are the positions mutated with respect to ``wt``.
        """
        sequence_labels, sequence_strs = df.log_fitness.values, df.seq.values
        mask_positions = [
            cls._mutated_positions(seq, wt) for seq in sequence_strs]
        return cls(sequence_labels, sequence_strs, mask_positions)

    def __getitem__(self, idx):
        """Return ``(label, sequence, mask_positions)`` for entry ``idx``.

        Raises:
            AttributeError: if the dataset was built without mask
                positions.
        """
        if self.mask_positions is None:
            raise AttributeError(
                'mask_positions is not set; build the dataset with '
                'from_file(fasta_file, wt) or from_dataframe(df, wt), or '
                'pass mask_positions to the constructor')
        return (self.sequence_labels[idx], self.sequence_strs[idx],
                self.mask_positions[idx])
class PLLFastaBatchedDataset(MaskedFastaBatchedDataset):
    """Batched dataset specialized for computing pseudo log likelihoods.

    Every sequence is expanded into one data entry per mutated position,
    each entry masking a single position. A sequence equal to the wild
    type contributes exactly one entry.
    """

    @classmethod
    def from_file(cls, fasta_file, wt):
        """Load ``fasta_file`` and emit one entry per (sequence, position)."""
        base = super().from_file(fasta_file, wt)
        labels, seqs, positions = [], [], []
        entries = zip(
            base.sequence_strs, base.sequence_labels, base.mask_positions)
        for seq, label, mutated in entries:
            # The wild type has no mutated position; position 1 is an
            # arbitrary choice so that it still yields one entry.
            targets = [1] if seq == wt else mutated
            for pos in targets:
                labels.append(label)
                seqs.append(seq)
                positions.append(pos)
        return cls(labels, seqs, positions)
class MaskedBatchConverter(BatchConverter):
    """Batch converter to be used with MaskedFastaBatchedDataset.

    Accepts ``(label, sequence, mask_position)`` triples and returns the
    parent's ``(labels, strs, tokens)`` plus a long tensor holding the
    mask positions.
    """

    def __call__(self, raw_batch):
        """Convert ``raw_batch`` into ``(labels, strs, tokens, mask_pos)``."""
        pairs, positions = [], []
        for label, seq, pos in raw_batch:
            pairs.append((label, seq))
            positions.append(pos)
        # NOTE(review): torch.tensor needs equally sized entries, so ragged
        # per-sequence position lists would fail here; confirm which
        # dataset feeds this converter.
        mask_pos = torch.tensor(positions).long()
        labels, strs, tokens = super().__call__(pairs)
        return labels, strs, tokens, mask_pos
class PLLBatchConverter(MaskedBatchConverter):
    """Alias of MaskedBatchConverter; adds no behavior of its own."""
def random_mask_tokens(inputs, alphabet, mlm_probability=0.15):
    """Corrupt a token tensor for masked language modeling.

    Every non-padding position is selected independently with probability
    ``mlm_probability``. Of the selected positions, 80% are overwritten
    with ``alphabet.mask_idx``, 10% with a random token index and 10% are
    left as they are.

    Args:
        inputs: integer token tensor. It is modified in place.
        alphabet: object providing ``padding_idx``, ``mask_idx``,
            ``prepend_toks`` and ``standard_toks``.
        mlm_probability: per-position selection probability.

    Returns:
        ``(inputs, labels, masked_indices)``: the corrupted ``inputs``
        (same tensor object that was passed in), a copy of the original
        tokens with every unselected position set to ``-100``, and the
        boolean selection mask.
    """
    shape, device = inputs.shape, inputs.device

    def _coin_flips(prob):
        # Independent Bernoulli(prob) draw for every position.
        return torch.bernoulli(torch.full(shape, prob, device=device)).bool()

    labels = inputs.clone()

    # Choose the positions to predict; padding is never chosen.
    select_prob = torch.full(shape, mlm_probability, device=device)
    select_prob.masked_fill_(inputs == alphabet.padding_idx, value=0.0)
    masked_indices = torch.bernoulli(select_prob).bool()
    # The loss is only computed on the chosen positions.
    labels.masked_fill_(~masked_indices, -100)

    # 80% of the chosen positions become the mask token.
    to_mask_token = _coin_flips(0.8) & masked_indices
    inputs.masked_fill_(to_mask_token, alphabet.mask_idx)

    # Half of the remaining 20% (10% overall) receive a random token.
    to_random_token = _coin_flips(0.5) & masked_indices & ~to_mask_token
    # NOTE(review): indices are drawn from
    # [len(prepend_toks), len(standard_toks)); if standard tokens are laid
    # out right after the prepended ones, this range stops short of the
    # last few standard tokens. Confirm against the alphabet's layout.
    replacement = torch.randint(len(alphabet.prepend_toks),
                                len(alphabet.standard_toks), shape,
                                dtype=torch.long, device=device)
    inputs[to_random_token] = replacement[to_random_token]

    # The last 10% of the chosen positions keep their original token.
    return inputs, labels, masked_indices