# Deep learning
import torch
# Data
from pubchem_encoder import Encoder
from datasets import load_dataset, concatenate_datasets
# Standard library
import os
import getpass
import glob
class MoleculeModule:
    def __init__(self, max_len, dataset, data_path):
        super().__init__()
        self.dataset = dataset
        self.data_path = data_path
        self.text_encoder = Encoder(max_len)

    def prepare_data(self):
        pass

    def get_vocab(self):
        # using a home-made character tokenizer; an existing tokenizer could be evaluated instead
        return self.text_encoder.char2id

    def get_cache(self):
        return self.cache_files
    def setup(self, stage=None):
        # Using the HuggingFace dataset loader.
        # Create the cache in the tmp directory of the local machine, under the
        # current user's name, to prevent locking issues.
        pubchem_path = {'train': self.data_path}
        if 'canonical' in pubchem_path['train'].lower():
            pubchem_script = './pubchem_canon_script.py'
        else:
            pubchem_script = './pubchem_script.py'
        zinc_path = './data/ZINC'
        global dataset_dict
        if 'ZINC' in self.dataset or 'zinc' in self.dataset:
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict = load_dataset('./zinc_script.py', data_files=self.dataset,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                        split='train', trust_remote_code=True)
        elif 'pubchem' in self.dataset:
            dataset_dict = load_dataset(pubchem_script, data_files=pubchem_path,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                        split='train', trust_remote_code=True)
        elif 'both' in self.dataset or 'Both' in self.dataset or 'BOTH' in self.dataset:
            dataset_dict_pubchem = load_dataset(pubchem_script, data_files=pubchem_path,
                                                cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                                split='train', trust_remote_code=True)
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict_zinc = load_dataset('./zinc_script.py', data_files=self.dataset,
                                             cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                             split='train', trust_remote_code=True)
            dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
        self.pubchem = dataset_dict
        print(dataset_dict.cache_files)
        # keep only the cache-directory root of each cache file (e.g. /tmp/<user>/pubchem)
        self.cache_files = []
        for cache in dataset_dict.cache_files:
            tmp = '/'.join(cache['filename'].split('/')[:4])
            self.cache_files.append(tmp)
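
# Illustrative usage sketch (not part of the original module). The max_len value,
# dataset name, and data path below are assumptions for demonstration only; the
# data path must point at an existing SMILES file for setup() to succeed.
#
#   dm = MoleculeModule(max_len=202, dataset='pubchem', data_path='./data/pubchem/CID-SMILES.smi')
#   dm.setup()
#   vocab = dm.get_vocab()    # char-to-id mapping from the custom Encoder
#   caches = dm.get_cache()   # cache-directory roots under /tmp/<user>/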
def get_optim_groups(module):
    # Set up the optimizer: separate all parameters into those that will and
    # won't experience regularizing weight decay.
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
    # map every full parameter name to its tensor
    param_dict = {pn: p for pn, p in module.named_parameters()}
    # create the parameter groups for the PyTorch optimizer
    # (note: both groups are built with weight_decay=0.0 here)
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    return optim_groups
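
# Minimal sketch of how the groups returned by get_optim_groups might be fed to a
# PyTorch optimizer (not part of the original file; the toy model, the choice of
# AdamW, and the learning rate are illustrative assumptions).
if __name__ == "__main__":
    _toy_model = torch.nn.Sequential(
        torch.nn.Embedding(32, 16),  # blacklisted: its weight lands in the no-decay group
        torch.nn.Linear(16, 16),     # whitelisted: its weight lands in the decay group
        torch.nn.LayerNorm(16),
    )
    _groups = get_optim_groups(_toy_model)
    _optimizer = torch.optim.AdamW(_groups, lr=1.6e-4)
    print(_optimizer)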