import random
from multiprocessing import Pool
from collections import UserList, defaultdict

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch

from rdkit import rdBase
from rdkit import Chem


# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
def set_torch_seed_to_all_gens(_):
    seed = torch.initial_seed() % (2**32 - 1)
    random.seed(seed)
    np.random.seed(seed)


class SpecialTokens:
    bos = '<bos>'
    eos = '<eos>'
    pad = '<pad>'
    unk = '<unk>'


class CharVocab:
    @classmethod
    def from_data(cls, data, *args, **kwargs):
        chars = set()
        for string in data:
            chars.update(string)

        return cls(chars, *args, **kwargs)

    def __init__(self, chars, ss=SpecialTokens):
        if (ss.bos in chars) or (ss.eos in chars) or \
                (ss.pad in chars) or (ss.unk in chars):
            raise ValueError('SpecialTokens in chars')

        all_syms = sorted(list(chars)) + [ss.bos, ss.eos, ss.pad, ss.unk]

        self.ss = ss
        self.c2i = {c: i for i, c in enumerate(all_syms)}
        self.i2c = {i: c for i, c in enumerate(all_syms)}

    def __len__(self):
        return len(self.c2i)

    @property
    def bos(self):
        return self.c2i[self.ss.bos]

    @property
    def eos(self):
        return self.c2i[self.ss.eos]

    @property
    def pad(self):
        return self.c2i[self.ss.pad]

    @property
    def unk(self):
        return self.c2i[self.ss.unk]

    def char2id(self, char):
        if char not in self.c2i:
            return self.unk

        return self.c2i[char]

    def id2char(self, id):
        if id not in self.i2c:
            return self.ss.unk

        return self.i2c[id]

    def string2ids(self, string, add_bos=False, add_eos=False):
        ids = [self.char2id(c) for c in string]

        if add_bos:
            ids = [self.bos] + ids
        if add_eos:
            ids = ids + [self.eos]

        return ids

    def ids2string(self, ids, rem_bos=True, rem_eos=True):
        if len(ids) == 0:
            return ''
        if rem_bos and ids[0] == self.bos:
            ids = ids[1:]
        if rem_eos and ids[-1] == self.eos:
            ids = ids[:-1]

        string = ''.join([self.id2char(id) for id in ids])

        return string


class OneHotVocab(CharVocab):
    def __init__(self, *args, **kwargs):
        super(OneHotVocab, self).__init__(*args, **kwargs)
        self.vectors = torch.eye(len(self.c2i))
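
# A minimal usage sketch (not part of the original module): it shows how a
# vocabulary built with CharVocab.from_data round-trips a SMILES string
# through string2ids/ids2string. The helper name and the SMILES below are
# arbitrary examples introduced here for illustration only.
def _example_vocab_roundtrip():
    smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
    vocab = CharVocab.from_data(smiles)
    ids = vocab.string2ids('CC(=O)O', add_bos=True, add_eos=True)
    # BOS/EOS tokens are stripped again on decoding
    assert vocab.ids2string(ids) == 'CC(=O)O'
    return ids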
def mapper(n_jobs):
    '''
    Returns function for map call.
    If n_jobs == 1, will use standard map
    If n_jobs > 1, will use multiprocessing pool
    If n_jobs is a pool object, will return its map function
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _mapper(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _mapper
    return n_jobs.map


class Logger(UserList):
    def __init__(self, data=None):
        super().__init__()
        self.sdata = defaultdict(list)
        for step in (data or []):
            self.append(step)

    def __getitem__(self, key):
        if isinstance(key, int):
            return self.data[key]
        if isinstance(key, slice):
            return Logger(self.data[key])
        ldata = self.sdata[key]
        if isinstance(ldata[0], dict):
            return Logger(ldata)
        return ldata

    def append(self, step_dict):
        super().append(step_dict)
        for k, v in step_dict.items():
            self.sdata[k].append(v)

    def save(self, path):
        df = pd.DataFrame(list(self))
        df.to_csv(path, index=None)


class LogPlotter:
    def __init__(self, log):
        self.log = log

    def line(self, ax, name):
        if isinstance(self.log[0][name], dict):
            for k in self.log[0][name]:
                ax.plot(self.log[name][k], label=k)
            ax.legend()
        else:
            ax.plot(self.log[name])

        ax.set_ylabel('value')
        ax.set_xlabel('epoch')
        ax.set_title(name)

    def grid(self, names, size=7):
        _, axs = plt.subplots(nrows=len(names) // 2, ncols=2,
                              figsize=(size * 2, size * (len(names) // 2)))

        for ax, name in zip(axs.flatten(), names):
            self.line(ax, name)


class CircularBuffer:
    def __init__(self, size):
        self.max_size = size
        self.data = np.zeros(self.max_size)
        self.size = 0
        self.pointer = -1

    def add(self, element):
        self.size = min(self.size + 1, self.max_size)
        self.pointer = (self.pointer + 1) % self.max_size
        self.data[self.pointer] = element
        return element

    def last(self):
        assert self.pointer != -1, "Can't get an element from an empty buffer!"
        return self.data[self.pointer]

    def mean(self):
        if self.size > 0:
            return self.data[:self.size].mean()
        return 0.0


def disable_rdkit_log():
    rdBase.DisableLog('rdApp.*')


def enable_rdkit_log():
    rdBase.EnableLog('rdApp.*')


def get_mol(smiles_or_mol):
    '''
    Loads SMILES/molecule into RDKit's object
    '''
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol
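
# A minimal usage sketch (not part of the original module): it shows how
# mapper() and get_mol() can be combined to parse a list of SMILES, either
# serially (n_jobs=1) or with a multiprocessing pool (n_jobs > 1). The helper
# name, the SMILES and the n_jobs default are illustrative assumptions.
def _example_parallel_parsing(smiles_list=('CCO', 'c1ccccc1'), n_jobs=1):
    disable_rdkit_log()         # silence RDKit warnings for invalid strings
    parse = mapper(n_jobs)      # map-like callable chosen according to n_jobs
    mols = parse(get_mol, list(smiles_list))
    enable_rdkit_log()
    # Entries are RDKit Mol objects; None marks SMILES RDKit could not parse
    return mols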
class StringDataset:
    def __init__(self, vocab, data):
        """
        Creates a convenient Dataset with SMILES tokenization

        Arguments:
            vocab: CharVocab instance for tokenization
            data (list): SMILES strings for the dataset
        """
        self.vocab = vocab
        self.tokens = [vocab.string2ids(s) for s in data]
        self.data = data
        self.bos = vocab.bos
        self.eos = vocab.eos

    def __len__(self):
        """
        Computes a number of objects in the dataset
        """
        return len(self.tokens)

    def __getitem__(self, index):
        """
        Prepares torch tensors with a given SMILES.

        Arguments:
            index (int): index of SMILES in the original dataset

        Returns:
            A tuple (with_bos, with_eos, smiles), where
            * with_bos is a torch.long tensor of SMILES tokens with
                BOS (beginning of a sentence) token
            * with_eos is a torch.long tensor of SMILES tokens with
                EOS (end of a sentence) token
            * smiles is an original SMILES from the dataset
        """
        tokens = self.tokens[index]
        with_bos = torch.tensor([self.bos] + tokens, dtype=torch.long)
        with_eos = torch.tensor(tokens + [self.eos], dtype=torch.long)
        return with_bos, with_eos, self.data[index]

    def default_collate(self, batch, return_data=False):
        """
        Simple collate function for SMILES dataset. Joins a
        batch of objects from StringDataset into a batch

        Arguments:
            batch: list of objects from StringDataset
            return_data: if True, will return SMILES used in a batch

        Returns:
            with_bos, with_eos, lengths [, data] where:
            * with_bos: padded sequence with BOS in the beginning
            * with_eos: padded sequence with EOS in the end
            * lengths: array with SMILES lengths in the batch
            * data: SMILES in the batch

        Note:
            Output batch is sorted with respect to SMILES lengths in
            decreasing order, since this is a default format for torch
            RNN implementations
        """
        with_bos, with_eos, data = list(zip(*batch))
        lengths = [len(x) for x in with_bos]
        order = np.argsort(lengths)[::-1]
        with_bos = [with_bos[i] for i in order]
        with_eos = [with_eos[i] for i in order]
        lengths = [lengths[i] for i in order]
        with_bos = torch.nn.utils.rnn.pad_sequence(
            with_bos, padding_value=self.vocab.pad
        )
        with_eos = torch.nn.utils.rnn.pad_sequence(
            with_eos, padding_value=self.vocab.pad
        )
        if return_data:
            data = np.array(data)[order]
            return with_bos, with_eos, lengths, data
        return with_bos, with_eos, lengths


def batch_to_device(batch, device):
    return [
        x.to(device) if isinstance(x, torch.Tensor) else x
        for x in batch
    ]
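
# An end-to-end sketch (not part of the original module): it builds a
# CharVocab and StringDataset from a few example SMILES and batches them with
# torch's DataLoader via default_collate, following the DataLoader link at the
# top of the file. The SMILES list and batch_size are arbitrary examples;
# worker_init_fn only takes effect when num_workers > 0.
if __name__ == '__main__':
    example_smiles = ['CCO', 'c1ccccc1', 'CC(=O)O', 'CCN(CC)CC']
    vocab = CharVocab.from_data(example_smiles)
    dataset = StringDataset(vocab, example_smiles)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False,
        collate_fn=dataset.default_collate,
        worker_init_fn=set_torch_seed_to_all_gens
    )
    for with_bos, with_eos, lengths in loader:
        # Tensors are padded to (max_len, batch); each batch is sorted by
        # SMILES length, longest first.
        print(with_bos.shape, with_eos.shape, lengths)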