#!/usr/bin/env python3
"""Run single- or multi-task classification using pre-trained transformer decoder and
Pfeiffer adapters.
"""
import argparse
import os.path
import statistics
import sys

from pandas import read_csv
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# `PfeifferConfig` is exposed by the adapter-transformers fork of HuggingFace
# transformers, which this script relies on for adapter support.
from transformers import GPT2Config, PfeifferConfig

# Make the bundled `smiles_gpt` package importable when running from this checkout.
sys.path.append('/root/autodl-tmp/wjm/iupac-gpt')
import smiles_gpt as gpt

RANDOM_SEED = 42  # Base seed; the fold index is added on every CV run.
REDUCTION_FACTOR, ACTIVATION = 16, "gelu"  # Adapter bottleneck factor and non-linearity.
ES_MIN_DELTA, ES_PATIENCE = 2e-3, 2  # Early-stopping improvement threshold and patience.


def main(options: argparse.Namespace):
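    # Extend the pre-trained GPT-2 configuration with the number of tasks and load
    # the matching BPE tokenizer so padding uses the checkpoint's pad token id.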
model_config = GPT2Config.from_pretrained(options.checkpoint)
model_config.num_tasks = len(options.tasks)
tokenizer_file = os.path.join(options.checkpoint, "tokenizer.json")
tokenizer = gpt.SMILESBPETokenizer.get_hf_tokenizer(
tokenizer_file, model_max_length=model_config.n_positions)
model_config.pad_token_id = tokenizer.pad_token_id
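    # Initialize the sequence classifier from the pre-trained language-model weights.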
model = gpt.GPT2ForSequenceClassification.from_pretrained(
options.checkpoint, config=model_config)
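    # Pfeiffer-style adapter: a single bottleneck module after each block's
    # feed-forward output, with no extra LayerNorms and no multi-head adapter.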
adapter_config = PfeifferConfig(
original_ln_before=True, original_ln_after=True, residual_before_ln=True,
adapter_residual_before_ln=False, ln_before=False, ln_after=False,
mh_adapter=False, output_adapter=True, non_linearity=ACTIVATION,
reduction_factor=REDUCTION_FACTOR, cross_adapter=False)
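    # Register the adapter under the dataset's name (basename avoids path
    # separators in the name) and train only its weights, keeping the
    # pre-trained backbone frozen.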
    adapter_name = os.path.splitext(os.path.basename(options.csv))[0]
model.add_adapter(adapter_name, config=adapter_config)
model.train_adapter(adapter_name)
model.set_active_adapters(adapter_name)
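    # Wrap the CSV in a LightningDataModule; the splitter produces random or
    # scaffold-based train/validation/test splits.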
data_frame = read_csv(options.csv)
splitter = gpt.CVSplitter(mode=options.split)
data_module = gpt.CSVDataModule(
data_frame, tokenizer, target_column=options.tasks,
has_empty_target=options.has_empty, num_workers=options.workers,
batch_size=options.batch_size, splitter=splitter)
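    # Stop training once validation AUC-ROC stops improving.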
    early_stopping = EarlyStopping(monitor="val_roc", min_delta=ES_MIN_DELTA,
                                   patience=ES_PATIENCE, mode="max")
trainer = Trainer(gpus=options.device, max_epochs=options.max_epochs,
callbacks=[early_stopping])
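    # LightningModule wrapper that adds the AdamW optimizer and an exponential
    # learning-rate schedule around the classifier.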
lit_model = gpt.ClassifierLitModel(
model, num_tasks=len(options.tasks), has_empty_labels=options.has_empty,
batch_size=options.batch_size, learning_rate=options.learning_rate,
scheduler_lambda=options.scheduler_lambda, weight_decay=options.weight_decay,
scheduler_step=options.scheduler_step)
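    # Fit on train/validation, then report AUC-PRC and AUC-ROC on the test split.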
trainer.fit(lit_model, data_module)
return trainer.test(model=lit_model, datamodule=data_module)


def process_options() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=("Run single- or multi-task classification tasks using pre-trained "
"transformer decoder and Pfeiffer adapters."),
add_help=True,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("checkpoint",
help=("The directory that stores HuggingFace transformer "
"configuration and tokenizer file `tokenizer.json`."))
parser.add_argument("csv", help="CSV file with SMILES entries and task labels")
parser.add_argument("tasks", help="Task names", nargs="+")
parser.add_argument("-e", "--has_empty", help="Whether tasks contain empty values",
action="store_true")
parser.add_argument("-d", "--device", help="`0` for CPU and `1` for GPU", default=1,
choices=(0, 1), type=int)
parser.add_argument("-w", "--workers", help="# of workers", default=0, type=int)
parser.add_argument("-b", "--batch_size",
help="Train/Val/Test data loader batch size",
type=int, default=64)
parser.add_argument("-l", "--learning_rate", help="The initial learning rate",
type=float, default=7e-4)
parser.add_argument("-m", "--max_epochs", help="The maximum number of epochs",
type=int, default=15)
parser.add_argument("-c", "--weight-decay", help="AdamW optimizer weight decay",
type=float, default=0.01)
parser.add_argument("-a", "--scheduler_lambda",
help="Lambda parameter of the exponential lr scheduler",
type=float, default=0.99)
parser.add_argument("-s", "--scheduler_step",
help="Step parameter of the exponential lr scheduler",
type=float, default=10)
parser.add_argument("-p", "--split", help="Data splitter",
choices=("random", "scaffold"), default="scaffold")
parser.add_argument("-k", "--num_folds",
help="Number of CV runs w/ different random seeds",
type=int, default=10)
return parser.parse_args()
if __name__ == "__main__":
from warnings import filterwarnings
from pytorch_lightning import seed_everything
from rdkit.RDLogger import DisableLog
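    # Silence noisy UserWarnings and RDKit logging during the CV runs.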
filterwarnings("ignore", category=UserWarning)
DisableLog("rdApp.*")
options = process_options()
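    # Cross-validation: retrain from the checkpoint with a different seed per
    # fold and collect the test metrics of each run.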
prc_results, roc_results = [], []
for fold_i in range(options.num_folds):
seed_everything(seed=RANDOM_SEED + fold_i)
results = main(options)[0]
prc_results.append(results["test_prc"])
roc_results.append(results["test_roc"])
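    # Aggregate across folds; the standard deviation needs at least two runs.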
if options.num_folds > 1:
mean_roc, std_roc = statistics.mean(roc_results), statistics.stdev(roc_results)
mean_prc, std_prc = statistics.mean(prc_results), statistics.stdev(prc_results)
else:
mean_roc, std_roc = roc_results[0], 0.
mean_prc, std_prc = prc_results[0], 0.
print(f"Mean AUC-ROC: {mean_roc:.3f} (+/-{std_roc:.3f})")
print(f"Mean AUC-PRC: {mean_prc:.3f} (+/-{std_prc:.3f})")