#!/usr/bin/env python3
"""Run single- or multi-task classification using pre-trained transformer decoder and
Pfeiffer adapters.
"""
import argparse
import os.path
import statistics
import sys

from pandas import read_csv
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# `PfeifferConfig` is exported by the adapter-transformers fork of HuggingFace
# transformers, which this script relies on for `add_adapter`/`train_adapter`.
from transformers import GPT2Config, PfeifferConfig

# Make the local `smiles_gpt` package importable.
sys.path.append('/root/autodl-tmp/wjm/iupac-gpt')
import smiles_gpt as gpt

RANDOM_SEED = 42
REDUCTION_FACTOR, ACTIVATION = 16, "gelu"
ES_MIN_DELTA, ES_PATIENCE = 2e-3, 2

def main(options: argparse.Namespace):
    # Load the pre-trained GPT-2 configuration and extend it with the number of
    # classification tasks.
    model_config = GPT2Config.from_pretrained(options.checkpoint)
    model_config.num_tasks = len(options.tasks)

    tokenizer_file = os.path.join(options.checkpoint, "tokenizer.json")
    tokenizer = gpt.SMILESBPETokenizer.get_hf_tokenizer(
        tokenizer_file, model_max_length=model_config.n_positions)
    model_config.pad_token_id = tokenizer.pad_token_id

    model = gpt.GPT2ForSequenceClassification.from_pretrained(
        options.checkpoint, config=model_config)

    # Pfeiffer bottleneck adapter: a single adapter block after the feed-forward
    # layer, with bottleneck size n_embd / REDUCTION_FACTOR.
    adapter_config = PfeifferConfig(
        original_ln_before=True, original_ln_after=True, residual_before_ln=True,
        adapter_residual_before_ln=False, ln_before=False, ln_after=False,
        mh_adapter=False, output_adapter=True, non_linearity=ACTIVATION,
        reduction_factor=REDUCTION_FACTOR, cross_adapter=False)
    # Name the adapter after the CSV file (basename only, so directory
    # components do not leak into the adapter name).
    adapter_name = os.path.splitext(os.path.basename(options.csv))[0]
    model.add_adapter(adapter_name, config=adapter_config)
    # Freeze the pre-trained weights; only the adapter and the classification
    # head remain trainable.
    model.train_adapter(adapter_name)
    model.set_active_adapters(adapter_name)

    data_frame = read_csv(options.csv)
    splitter = gpt.CVSplitter(mode=options.split)
    data_module = gpt.CSVDataModule(
        data_frame, tokenizer, target_column=options.tasks,
        has_empty_target=options.has_empty, num_workers=options.workers,
        batch_size=options.batch_size, splitter=splitter)

    # Stop training early once the validation ROC-AUC plateaus.
    early_stopping = EarlyStopping(monitor="val_roc", min_delta=ES_MIN_DELTA,
                                   patience=ES_PATIENCE, mode="max")
    trainer = Trainer(gpus=options.device, max_epochs=options.max_epochs,
                      callbacks=[early_stopping])
    lit_model = gpt.ClassifierLitModel(
        model, num_tasks=len(options.tasks), has_empty_labels=options.has_empty,
        batch_size=options.batch_size, learning_rate=options.learning_rate,
        scheduler_lambda=options.scheduler_lambda, weight_decay=options.weight_decay,
        scheduler_step=options.scheduler_step)

    trainer.fit(lit_model, data_module)
    return trainer.test(model=lit_model, datamodule=data_module)

def process_options() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=("Run single- or multi-task classification using a pre-trained "
                     "transformer decoder and Pfeiffer adapters."),
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("checkpoint",
                        help=("The directory that stores the HuggingFace transformer "
                              "configuration and the tokenizer file `tokenizer.json`."))
    parser.add_argument("csv", help="CSV file with SMILES entries and task labels.")
    parser.add_argument("tasks", help="Task (label column) names.", nargs="+")
    parser.add_argument("-e", "--has_empty", help="Whether tasks contain empty values.",
                        action="store_true")
    parser.add_argument("-d", "--device", help="`0` for CPU, `1` for GPU.", default=1,
                        choices=(0, 1), type=int)
    parser.add_argument("-w", "--workers", help="Number of data loader workers.",
                        default=0, type=int)
    parser.add_argument("-b", "--batch_size",
                        help="Train/val/test data loader batch size.",
                        type=int, default=64)
    parser.add_argument("-l", "--learning_rate", help="The initial learning rate.",
                        type=float, default=7e-4)
    parser.add_argument("-m", "--max_epochs", help="The maximum number of epochs.",
                        type=int, default=15)
    parser.add_argument("-c", "--weight_decay", help="AdamW optimizer weight decay.",
                        type=float, default=0.01)
    parser.add_argument("-a", "--scheduler_lambda",
                        help="Lambda parameter of the exponential LR scheduler.",
                        type=float, default=0.99)
    parser.add_argument("-s", "--scheduler_step",
                        help="Step parameter of the exponential LR scheduler.",
                        type=int, default=10)
    parser.add_argument("-p", "--split", help="Data splitting strategy.",
                        choices=("random", "scaffold"), default="scaffold")
    parser.add_argument("-k", "--num_folds",
                        help="Number of CV runs with different random seeds.",
                        type=int, default=10)
    return parser.parse_args()

if __name__ == "__main__":
    from warnings import filterwarnings

    from pytorch_lightning import seed_everything
    from rdkit.RDLogger import DisableLog

    filterwarnings("ignore", category=UserWarning)
    DisableLog("rdApp.*")

    options = process_options()

    # Repeat training and evaluation `num_folds` times with different random
    # seeds, then aggregate the test metrics.
    prc_results, roc_results = [], []
    for fold_i in range(options.num_folds):
        seed_everything(seed=RANDOM_SEED + fold_i)
        results = main(options)[0]
        prc_results.append(results["test_prc"])
        roc_results.append(results["test_roc"])

    if options.num_folds > 1:
        mean_roc, std_roc = statistics.mean(roc_results), statistics.stdev(roc_results)
        mean_prc, std_prc = statistics.mean(prc_results), statistics.stdev(prc_results)
    else:
        # `statistics.stdev` requires at least two data points.
        mean_roc, std_roc = roc_results[0], 0.0
        mean_prc, std_prc = prc_results[0], 0.0

    print(f"Mean AUC-ROC: {mean_roc:.3f} (+/-{std_roc:.3f})")
    print(f"Mean AUC-PRC: {mean_prc:.3f} (+/-{std_prc:.3f})")