|
import regex as re |
|
import torch |
|
import numpy as np |
|
import random |
|
import collections |
|
|
|
class Encoder():
    """Regex-tokenize molecule strings, bucket them by length, and build masked-LM batches.

    The tokenizer pattern targets SMILES syntax (bracket atoms, two-letter halogens,
    bond symbols, ring-closure digits and ``%nn`` labels). Token ids 0-3 are reserved
    for ``<bos>``, ``<eos>``, ``<pad>`` and ``<mask>``. Vocabulary entries follow from
    id 4 in the vocabulary's iteration order, and a final ``<unk>`` id comes last.

    ``process_text`` sorts tokenized molecules into four length buckets whose
    capacities depend on ``feature_size``. A molecule that arrives while its bucket
    is full is parked in that bucket's cache and used to top up later batches.
    """

    def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32,
                 vocab_path='pubchem_canon_zinc_final_vocab_sorted_curated.pth'):
        """Load the vocabulary and set up tokenizer, bucket and masking state.

        Args:
            max_length: exclusive upper bound on a molecule's token count; longer
                molecules are dropped by ``process_text``.
            add_bos: prepend ``<bos>`` in ``encode``.
            add_eos: append ``<eos>`` in ``encode``.
            feature_size: 32 selects the small bucket-capacity table; any other
                value selects the large one.
            vocab_path: file holding the vocabulary mapping (token -> numeric value;
                the values are thresholded as counts below), read with ``torch.load``.
        """
        # torch.load uses pickle: only point vocab_path at trusted files.
        self.vocab_encoder = torch.load(vocab_path)

        self.max_length = max_length
        self.min_length = 1
        # Token-count boundaries between buckets 0|1, 1|2 and 2|3.
        self.mod_length = 42
        self.avg_length = 66
        self.tail = 122
        self.mlm_probability = .15

        # Per-bucket overflow caches and current-batch buckets.
        self.b0_cache = collections.deque()
        self.b1_cache = collections.deque()
        self.b2_cache = collections.deque()
        self.b3_cache = collections.deque()
        self.bucket0 = collections.deque()
        self.bucket1 = collections.deque()
        self.bucket2 = collections.deque()
        self.bucket3 = collections.deque()

        # Upper bound on how many molecules each bucket emits per process_text call.
        if feature_size == 32:
            self.b0_max, self.b1_max, self.b2_max, self.b3_max = 1100, 700, 150, 50
        else:
            self.b0_max, self.b1_max, self.b2_max, self.b3_max = 1382, 871, 516, 311

        # Number of vocab entries valued above 100000, and valued in (50, 100000].
        # The +4 presumably accounts for the four reserved special ids.
        # NOTE(review): likely consumed as adaptive-softmax cutoffs elsewhere - confirm.
        counts = list(self.vocab_encoder.values())
        num_top = sum(1 for count in counts if count > 100000)
        middle_top = sum(1 for count in counts if 50 < count <= 100000)
        self.cutoffs = [num_top + 4, middle_top]

        self.char2id = {"<bos>": 0, "<eos>": 1, "<pad>": 2, "<mask>": 3}
        self.id2char = {0: "<bos>", 1: "<eos>", 2: "<pad>", 3: "<mask>"}
        self.pad = self.char2id['<pad>']
        self.mask = self.char2id['<mask>']
        self.eos = self.char2id['<eos>']
        self.bos = self.char2id['<bos>']

        # Vocabulary tokens take consecutive ids right after the special tokens.
        for token_id, key in enumerate(self.vocab_encoder, start=4):
            self.char2id[key] = token_id
            self.id2char[token_id] = key
        # <unk> gets the id after the last vocabulary entry.
        unk_id = len(self.vocab_encoder) + 4
        self.char2id["<unk>"] = unk_id
        self.id2char[unk_id] = "<unk>"
        self.unk = unk_id

        # Raw string: same pattern value as before, without invalid-escape warnings.
        self.pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.pattern)

        self.add_bos = add_bos
        self.add_eos = add_eos

    def encode(self, char):
        """Map one token sequence to a 1-D ``torch.long`` tensor of ids.

        ``<bos>``/``<eos>`` are added according to ``add_bos``/``add_eos``. Tokens
        missing from the vocabulary map to the ``<unk>`` id instead of raising.
        """
        unk = self.char2id["<unk>"]
        tokens = list(char)
        if self.add_bos:
            tokens = ['<bos>'] + tokens
        if self.add_eos:
            tokens = tokens + ['<eos>']
        # Explicit dtype so an empty sequence is not created as float32.
        return torch.tensor([self.char2id.get(tok, unk) for tok in tokens], dtype=torch.long)

    def encoder(self, tokens):
        """Encode each token sequence in ``tokens``; returns a list of id tensors."""
        return [self.encode(mol) for mol in tokens]

    def _bucket_index(self, length):
        """Return the bucket (0-3) for a tokenized ``length``, or None to drop it."""
        if self.min_length < length < self.mod_length:
            return 0
        if self.mod_length <= length < self.avg_length:
            return 1
        if self.avg_length <= length < self.tail:
            return 2
        if self.tail <= length < self.max_length:
            return 3
        return None

    @staticmethod
    def _flush_bucket(bucket, cache, max_size):
        """Empty ``bucket`` into a list, topped up from ``cache`` to at most ``max_size``."""
        # The margin must be measured before the bucket is drained.
        n_from_cache = max(0, min(len(cache), max_size - len(bucket)))
        out = [bucket.pop() for _ in range(len(bucket))]
        out.extend(cache.pop() for _ in range(n_from_cache))
        return out

    def process_text(self, text):
        """Tokenize molecules and return one batch of token lists per length bucket.

        Args:
            text: iterable of mappings whose ``'text'`` entry is a molecule string.

        Returns:
            A 4-tuple of lists (buckets 0..3) of token lists. Each list holds the
            bucket's contents (popped newest-first), topped up from its cache.
            Molecules outside every bucket's length range are dropped.
        """
        buckets = (self.bucket0, self.bucket1, self.bucket2, self.bucket3)
        caches = (self.b0_cache, self.b1_cache, self.b2_cache, self.b3_cache)
        limits = (self.b0_max, self.b1_max, self.b2_max, self.b3_max)

        for mol in text:
            if '\n' in mol['text']:
                print('carriage return in mol')
            raw_regex = self.regex.findall(mol['text'].strip('\n'))
            idx = self._bucket_index(len(raw_regex))
            if idx is None:
                continue
            # Full buckets overflow into their cache for later batches.
            if len(buckets[idx]) < limits[idx]:
                buckets[idx].append(raw_regex)
            else:
                caches[idx].append(raw_regex)

        return tuple(self._flush_bucket(bucket, cache, limit)
                     for bucket, cache, limit in zip(buckets, caches, limits))

    def _special_tokens_mask(self, ids):
        """Boolean tensor marking ``<bos>``, ``<eos>`` and ``<pad>`` positions in ``ids``."""
        return (ids == self.bos) | (ids == self.eos) | (ids == self.pad)

    def mask_tokens(self, inputs, special_tokens_mask=None):
        """Prepare masked tokens inputs/labels for masked language modeling.

        Each non-special position is selected with probability
        ``mlm_probability``. Selected positions become ``<mask>`` 80% of the time,
        a random id 10% of the time, and stay unchanged 10% of the time.

        Args:
            inputs: integer id tensor; modified IN PLACE and also returned.
            special_tokens_mask: optional same-shape mask (nested list or tensor)
                of positions that must never be selected. When omitted, it is
                derived from the ``<bos>``/``<eos>``/``<pad>`` ids.

        Returns:
            ``(inputs, labels)``. ``labels`` holds the original id at selected
            positions and -100 elsewhere. -100 is the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``; presumably the downstream loss relies on
            that - confirm.
        """
        labels = inputs.clone()
        device = labels.device

        probability_matrix = torch.full(labels.size(), self.mlm_probability, device=device)
        if special_tokens_mask is None:
            special_tokens_mask = self._special_tokens_mask(labels)
        else:
            special_tokens_mask = torch.as_tensor(special_tokens_mask, dtype=torch.bool, device=device)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        # 80% of the selected positions become <mask>.
        indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8, device=device)).bool() & masked_indices
        inputs[indices_replaced] = self.mask

        # Half of the remainder (10% overall) get a random id. NOTE: the id range
        # covers every id in char2id, special ids included.
        indices_random = (torch.bernoulli(torch.full(labels.size(), 0.5, device=device)).bool()
                          & masked_indices & ~indices_replaced)
        random_words = torch.randint(len(self.char2id), labels.size(), dtype=torch.long, device=device)
        inputs[indices_random] = random_words[indices_random]

        # The last 10% keep their original id.
        return inputs, labels

    def pack_tensors(self, tokens):
        """Encode, pad and mask one bucket of token lists.

        Returns:
            ``(masked_array, masked_labels, array_ids, lengths)``. ``array_ids`` is
            the list of unpadded per-molecule id tensors. ``lengths`` counts the
            non-pad ids of each padded row.
        """
        array_ids = self.encoder(tokens)
        array = torch.nn.utils.rnn.pad_sequence(array_ids, batch_first=True, padding_value=self.pad)
        # Measure lengths before masking, because mask_tokens edits `array` in place.
        lengths = (array != self.pad).sum(dim=-1)
        special_token_mask = self._special_tokens_mask(array)
        masked_array, masked_labels = self.mask_tokens(array, special_token_mask)
        return masked_array, masked_labels, array_ids, lengths

    def process(self, text):
        """Bucket ``text`` and pack every non-empty bucket into masked-LM tensors.

        Returns four parallel lists with one entry per non-empty bucket: masked id
        tensors, label tensors, lists of unpadded id tensors, and length tensors.
        """
        arrays = []
        lengths = []
        targets = []
        arrays_ids = []
        for tokens in self.process_text(text):
            if tokens:
                array, target, array_ids, lgt = self.pack_tensors(tokens)
                arrays.append(array)
                targets.append(target)
                arrays_ids.append(array_ids)
                lengths.append(lgt)
        return arrays, targets, arrays_ids, lengths
|
|
|
if __name__ == "__main__":
    # Smoke run: building the encoder loads the vocabulary file and constructs
    # the token <-> id maps; nothing else is executed.
    text_encoder = Encoder()
|
|