|
|
|
import torch |
|
|
|
|
|
from pubchem_encoder import Encoder |
|
from datasets import load_dataset |
|
|
|
|
|
import os |
|
import getpass |
|
import glob |
|
|
|
|
|
class MoleculeModule:
    """Load the ZINC and/or PubChem SMILES corpora with Hugging Face ``datasets``.

    Holds an ``Encoder`` (from ``pubchem_encoder``) whose ``char2id`` mapping is
    exposed via :meth:`get_vocab`. After :meth:`setup`, the loaded training split
    is available as ``self.pubchem`` and its cache directories via :meth:`get_cache`.
    """

    def __init__(self, max_len, dataset, data_path):
        """Store configuration and build the text encoder.

        Args:
            max_len: Forwarded unchanged to ``Encoder`` (presumably the maximum
                encoded sequence length -- not checked here).
            dataset: Corpus selector used by :meth:`setup`, matched by substring:
                ``'ZINC'``/``'zinc'`` -> ZINC only, ``'pubchem'`` -> PubChem only,
                ``'both'``/``'Both'``/``'BOTH'`` -> both, concatenated.
            data_path: Path of the PubChem training file. If it contains
                ``'canonical'`` (case-insensitive), the canonical-SMILES loading
                script is used.
        """
        super().__init__()
        self.dataset = dataset
        self.data_path = data_path
        self.text_encoder = Encoder(max_len)

    def prepare_data(self):
        """No-op; all loading happens in :meth:`setup`."""
        pass

    def get_vocab(self):
        """Return the encoder's ``char2id`` token-to-id mapping."""
        return self.text_encoder.char2id

    def get_cache(self):
        """Return the cache directories recorded by :meth:`setup`."""
        return self.cache_files

    def _load_pubchem(self):
        """Load the PubChem file at ``self.data_path`` as the ``'train'`` split."""
        if 'canonical' in self.data_path.lower():
            pubchem_script = './pubchem_canon_script.py'
        else:
            pubchem_script = './pubchem_script.py'
        return load_dataset(
            pubchem_script,
            data_files={'train': self.data_path},
            cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
            split='train',
            trust_remote_code=True,
        )

    def _load_zinc(self, zinc_path='./data/ZINC'):
        """Load every ``*.smi`` file under ``zinc_path`` via ``./zinc_script.py``.

        Raises:
            FileNotFoundError: if ``zinc_path`` contains no ``*.smi`` files.
        """
        zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
        if not zinc_files:
            raise FileNotFoundError(f"No *.smi files found in {zinc_path!r}")
        for zfile in zinc_files:
            print(zfile)
        # Preserved from the original implementation: self.dataset is replaced by
        # the data_files mapping (other code may read it after setup()).
        self.dataset = {'train': zinc_files}
        return load_dataset(
            './zinc_script.py',
            data_files=self.dataset,
            cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
            split='train',
            trust_remote_code=True,
        )

    def setup(self, stage=None):
        """Load the selected corpus and record where it was cached.

        Args:
            stage: Unused; accepted for a Lightning-style ``setup`` signature.

        Side effects:
            * Assigns the loaded split to ``self.pubchem`` and to the module-level
              global ``dataset_dict``.
            * For the ZINC and "both" selections, replaces ``self.dataset`` with
              ``{'train': [<ZINC .smi files>]}``.
            * Sets ``self.cache_files`` to the first four ``'/'``-separated
              components of each cache file path, e.g. ``/tmp/<user>/zinc``.

        Raises:
            ValueError: if ``self.dataset`` matches none of the supported
                selections (previously this fell through to a ``NameError`` or
                reused a stale global).
        """
        global dataset_dict

        if 'ZINC' in self.dataset or 'zinc' in self.dataset:
            dataset_dict = self._load_zinc()
        elif 'pubchem' in self.dataset:
            dataset_dict = self._load_pubchem()
        elif 'both' in self.dataset or 'Both' in self.dataset or 'BOTH' in self.dataset:
            # Local import: only load_dataset is imported at module level.
            from datasets import concatenate_datasets

            dataset_dict_pubchem = self._load_pubchem()
            dataset_dict_zinc = self._load_zinc()
            dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
        else:
            raise ValueError(
                f"Unknown dataset {self.dataset!r}: expected a name containing "
                "'ZINC'/'zinc', 'pubchem', or 'both'/'Both'/'BOTH'"
            )

        self.pubchem = dataset_dict
        print(dataset_dict.cache_files)
        # An absolute path '/tmp/<user>/zinc/...' splits into
        # ['', 'tmp', '<user>', 'zinc', ...]; keeping 4 parts yields '/tmp/<user>/zinc'.
        self.cache_files = [
            '/'.join(cache['filename'].split('/')[:4])
            for cache in dataset_dict.cache_files
        ]
|
|
|
|
|
def get_optim_groups(module, weight_decay=0.0):
    """Split ``module``'s parameters into decay / no-decay optimizer groups.

    Classification:
      * any parameter whose name ends in ``bias`` -> no decay;
      * ``weight`` of a ``torch.nn.Linear`` -> decay;
      * ``weight`` of a ``torch.nn.LayerNorm`` or ``torch.nn.Embedding`` -> no decay;
      * any other parameter (e.g. a bare ``nn.Parameter`` or a weight of another
        module type) -> no decay. Previously such parameters were left out of
        every group and therefore never optimised.

    Args:
        module: The ``torch.nn.Module`` whose parameters are grouped.
        weight_decay: Decay applied to the decay group. Defaults to ``0.0``,
            which reproduces the original behaviour (neither group is decayed).

    Returns:
        Two dicts usable as ``torch.optim`` parameter groups:
        ``[{"params": decay_params, "weight_decay": weight_decay},
           {"params": no_decay_params, "weight_decay": 0.0}]``.
        Each parameter appears exactly once; within a group, parameters are
        ordered by fully-qualified name.
    """
    # module.named_parameters() de-duplicates shared (tied) tensors, so this maps
    # each distinct parameter to a single canonical name.
    param_dict = {pn: p for pn, p in module.named_parameters()}

    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        # named_parameters() recurses, so a tensor is visited once per ancestor
        # module; its full name ``fpn`` is the same on every visit.
        for pn, _ in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn
            if fpn not in param_dict:
                # Alias of a tied parameter; it is classified under its canonical name.
                continue
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    # torch.optim rejects a parameter listed in two groups; no-decay wins on conflict.
    decay -= no_decay
    # Parameters no rule matched must still be optimised.
    no_decay |= param_dict.keys() - decay

    return [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]