import random
from collections import UserList, defaultdict
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from rdkit import Chem, rdBase


# Seed Python's and NumPy's RNGs inside each DataLoader worker from the torch
# seed; intended to be passed as worker_init_fn, see
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
def set_torch_seed_to_all_gens(_):
    seed = torch.initial_seed() % (2**32 - 1)
    random.seed(seed)
    np.random.seed(seed)
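

# Illustrative sketch (not part of the original module): how the helper above
# is typically hooked into a DataLoader so each worker re-seeds its RNGs.
# The `dataset` argument here is a stand-in for any torch Dataset.
def _example_worker_seeding(dataset):
    from torch.utils.data import DataLoader
    return DataLoader(dataset, batch_size=64, num_workers=2,
                      worker_init_fn=set_torch_seed_to_all_gens)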


class SpecialTokens:
    bos = '<bos>'
    eos = '<eos>'
    pad = '<pad>'
    unk = '<unk>'


class CharVocab:
    @classmethod
    def from_data(cls, data, *args, **kwargs):
        chars = set()
        for string in data:
            chars.update(string)

        return cls(chars, *args, **kwargs)

    def __init__(self, chars, ss=SpecialTokens):
        if (ss.bos in chars) or (ss.eos in chars) or \
                (ss.pad in chars) or (ss.unk in chars):
            raise ValueError('SpecialTokens in chars')

        all_syms = sorted(list(chars)) + [ss.bos, ss.eos, ss.pad, ss.unk]

        self.ss = ss
        self.c2i = {c: i for i, c in enumerate(all_syms)}
        self.i2c = {i: c for i, c in enumerate(all_syms)}

    def __len__(self):
        return len(self.c2i)

    @property
    def bos(self):
        return self.c2i[self.ss.bos]

    @property
    def eos(self):
        return self.c2i[self.ss.eos]

    @property
    def pad(self):
        return self.c2i[self.ss.pad]

    @property
    def unk(self):
        return self.c2i[self.ss.unk]

    def char2id(self, char):
        if char not in self.c2i:
            return self.unk

        return self.c2i[char]

    def id2char(self, id):
        if id not in self.i2c:
            return self.ss.unk

        return self.i2c[id]

    def string2ids(self, string, add_bos=False, add_eos=False):
        ids = [self.char2id(c) for c in string]
        if add_bos:
            ids = [self.bos] + ids
        if add_eos:
            ids = ids + [self.eos]

        return ids

    def ids2string(self, ids, rem_bos=True, rem_eos=True):
        if len(ids) == 0:
            return ''
        if rem_bos and ids[0] == self.bos:
            ids = ids[1:]
        if rem_eos and ids[-1] == self.eos:
            ids = ids[:-1]

        return ''.join([self.id2char(id) for id in ids])


class OneHotVocab(CharVocab):
    def __init__(self, *args, **kwargs):
        super(OneHotVocab, self).__init__(*args, **kwargs)
        self.vectors = torch.eye(len(self.c2i))
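

# Usage sketch (not part of the original module): build a vocabulary from a
# few arbitrary SMILES strings and round-trip one of them through token ids.
def _example_vocab_roundtrip():
    smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
    vocab = CharVocab.from_data(smiles)
    ids = vocab.string2ids('CCO', add_bos=True, add_eos=True)
    assert vocab.ids2string(ids) == 'CCO'  # bos/eos are stripped on decode
    return vocab, ids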


def mapper(n_jobs):
    '''
    Returns function for map call.
    If n_jobs == 1, will use standard map
    If n_jobs > 1, will use multiprocessing pool
    If n_jobs is a pool object, will return its map function
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _mapper(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _mapper
    return n_jobs.map
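

# Usage sketch (not part of the original module): with n_jobs=1 the returned
# callable is just a serial map. mapper(4)(func, data) would run the same call
# on a 4-worker Pool, but then func must be picklable (no lambdas).
def _example_mapper_usage():
    squares = mapper(1)(lambda x: x * x, range(5))  # [0, 1, 4, 9, 16]
    return squares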


class Logger(UserList):
    def __init__(self, data=None):
        super().__init__()
        self.sdata = defaultdict(list)
        for step in (data or []):
            self.append(step)

    def __getitem__(self, key):
        if isinstance(key, int):
            return self.data[key]
        if isinstance(key, slice):
            return Logger(self.data[key])
        ldata = self.sdata[key]
        if isinstance(ldata[0], dict):
            return Logger(ldata)
        return ldata

    def append(self, step_dict):
        super().append(step_dict)
        for k, v in step_dict.items():
            self.sdata[k].append(v)

    def save(self, path):
        df = pd.DataFrame(list(self))
        df.to_csv(path, index=None)


class LogPlotter:
    def __init__(self, log):
        self.log = log

    def line(self, ax, name):
        if isinstance(self.log[0][name], dict):
            for k in self.log[0][name]:
                ax.plot(self.log[name][k], label=k)
            ax.legend()
        else:
            ax.plot(self.log[name])

        ax.set_ylabel('value')
        ax.set_xlabel('epoch')
        ax.set_title(name)

    def grid(self, names, size=7):
        _, axs = plt.subplots(nrows=len(names) // 2, ncols=2,
                              figsize=(size * 2, size * (len(names) // 2)))
        for ax, name in zip(axs.flatten(), names):
            self.line(ax, name)
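

# Usage sketch (not part of the original module): Logger stores one dict per
# training step; indexing by a key returns the per-step series, which
# LogPlotter draws. The keys 'loss' and 'kl_weight' are made-up examples.
def _example_logger_usage():
    log = Logger()
    log.append({'loss': 1.0, 'kl_weight': 0.1})
    log.append({'loss': 0.7, 'kl_weight': 0.2})
    losses = log['loss']  # [1.0, 0.7]
    _, ax = plt.subplots()
    LogPlotter(log).line(ax, 'loss')
    return losses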


class CircularBuffer:
    def __init__(self, size):
        self.max_size = size
        self.data = np.zeros(self.max_size)
        self.size = 0
        self.pointer = -1

    def add(self, element):
        self.size = min(self.size + 1, self.max_size)
        self.pointer = (self.pointer + 1) % self.max_size
        self.data[self.pointer] = element
        return element

    def last(self):
        assert self.pointer != -1, "Can't get an element from an empty buffer!"
        return self.data[self.pointer]

    def mean(self):
        if self.size > 0:
            return self.data[:self.size].mean()
        return 0.0
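

# Usage sketch (not part of the original module): a CircularBuffer keeps only
# the last `size` values, which makes it handy for a running mean of a loss.
def _example_running_mean():
    buf = CircularBuffer(3)
    for value in [1.0, 2.0, 3.0, 4.0]:
        buf.add(value)
    # Only 2.0, 3.0 and 4.0 are retained, so last() is 4.0 and mean() is 3.0
    return buf.last(), buf.mean()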


def disable_rdkit_log():
    rdBase.DisableLog('rdApp.*')


def enable_rdkit_log():
    rdBase.EnableLog('rdApp.*')


def get_mol(smiles_or_mol):
    '''
    Loads SMILES/molecule into RDKit's object
    '''
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol
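

# Usage sketch (not part of the original module): get_mol returns an RDKit Mol
# for a valid SMILES, passes Mol objects through untouched, and returns None
# for strings it cannot parse.
def _example_get_mol():
    assert get_mol('CCO') is not None           # valid SMILES
    assert get_mol(get_mol('CCO')) is not None  # Mol objects pass through
    assert get_mol('') is None                  # empty string is rejected
    assert get_mol('C(') is None                # unparsable SMILES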


class StringDataset:
    def __init__(self, vocab, data):
        """
        Creates a convenient Dataset with SMILES tokenization

        Arguments:
            vocab: CharVocab instance for tokenization
            data (list): SMILES strings for the dataset
        """
        self.vocab = vocab
        self.tokens = [vocab.string2ids(s) for s in data]
        self.data = data
        self.bos = vocab.bos
        self.eos = vocab.eos

    def __len__(self):
        """
        Computes the number of objects in the dataset
        """
        return len(self.tokens)

    def __getitem__(self, index):
        """
        Prepares torch tensors for a given SMILES.

        Arguments:
            index (int): index of SMILES in the original dataset

        Returns:
            A tuple (with_bos, with_eos, smiles), where
            * with_bos is a torch.long tensor of SMILES tokens with
                the BOS (beginning of a sentence) token
            * with_eos is a torch.long tensor of SMILES tokens with
                the EOS (end of a sentence) token
            * smiles is the original SMILES from the dataset
        """
        tokens = self.tokens[index]
        with_bos = torch.tensor([self.bos] + tokens, dtype=torch.long)
        with_eos = torch.tensor(tokens + [self.eos], dtype=torch.long)
        return with_bos, with_eos, self.data[index]

    def default_collate(self, batch, return_data=False):
        """
        Simple collate function for SMILES dataset. Joins a
        batch of objects from StringDataset into a batch

        Arguments:
            batch: list of objects from StringDataset
            return_data: if True, will return SMILES used in the batch

        Returns:
            with_bos, with_eos, lengths [, data], where
            * with_bos: padded sequence with BOS in the beginning
            * with_eos: padded sequence with EOS in the end
            * lengths: list with SMILES lengths in the batch
            * data: SMILES in the batch

        Note: the output batch is sorted by SMILES length in
            decreasing order, since this is the default format for torch
            RNN implementations
        """
        with_bos, with_eos, data = list(zip(*batch))
        lengths = [len(x) for x in with_bos]
        order = np.argsort(lengths)[::-1]
        with_bos = [with_bos[i] for i in order]
        with_eos = [with_eos[i] for i in order]
        lengths = [lengths[i] for i in order]
        with_bos = torch.nn.utils.rnn.pad_sequence(
            with_bos, padding_value=self.vocab.pad
        )
        with_eos = torch.nn.utils.rnn.pad_sequence(
            with_eos, padding_value=self.vocab.pad
        )
        if return_data:
            data = np.array(data)[order]
            return with_bos, with_eos, lengths, data
        return with_bos, with_eos, lengths


def batch_to_device(batch, device):
    return [
        x.to(device) if isinstance(x, torch.Tensor) else x
        for x in batch
    ]
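

# Minimal end-to-end sketch (not part of the original module): build a vocab
# and a StringDataset from a few example SMILES, then collate one batch via a
# torch DataLoader with default_collate as the collate_fn.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    smiles = ['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN']
    vocab = CharVocab.from_data(smiles)
    dataset = StringDataset(vocab, smiles)
    loader = DataLoader(dataset, batch_size=2, shuffle=False,
                        collate_fn=dataset.default_collate)
    with_bos, with_eos, lengths = next(iter(loader))
    # pad_sequence stacks along dim 0 = time, so shapes are (max_len, batch)
    print(with_bos.shape, with_eos.shape, lengths)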