# Source: materials.smi-ted / smi-ted / finetune / finetune_classification.py
# Uploaded by eduardosoares99 ("Upload 159 files", commit 02e480f, verified), 1.86 kB
# Deep learning
import torch
import torch.nn as nn
from torch import optim
from trainers import TrainerClassifier
from utils import get_optim_groups
# Data
import pandas as pd
import numpy as np
# Project CLI arguments and standard library
import args
import os
def main(config):
    """Fine-tune a SMI-TED model on a classification dataset.

    Reads ``train.csv``, ``valid.csv`` and ``test.csv`` from
    ``config.data_root``, loads the checkpoint variant selected by
    ``config.smi_ted_version``, builds an AdamW optimizer and hands
    everything to ``TrainerClassifier`` for ``config.max_epochs`` epochs.

    Args:
        config: Parsed command-line namespace. Attributes read here:
            ``data_root``, ``smi_ted_version``, ``model_path``,
            ``ckpt_filename``, ``n_output``, ``lr_start``,
            ``lr_multiplier``, ``train_decoder``, ``loss_fn``,
            ``dataset_name``, ``measure_name``, ``n_batch``,
            ``target_metric``, ``start_seed``, ``checkpoints_folder``,
            ``save_ckpt`` and ``max_epochs``. The whole namespace is also
            forwarded to the trainer as ``hparams``.

    Raises:
        ValueError: If ``config.smi_ted_version`` is not ``'v1'`` or
            ``'v2'``, or if ``config.loss_fn`` is not ``'crossentropy'``.
    """
    # Reject unsupported options up front. Without these checks the
    # function would only fail later with an UnboundLocalError on
    # `load_smi_ted` / `loss_function`, after data and weights were loaded.
    if config.smi_ted_version not in ('v1', 'v2'):
        raise ValueError(
            f"Unsupported smi_ted_version {config.smi_ted_version!r}; "
            "expected 'v1' or 'v2'."
        )
    if config.loss_fn != 'crossentropy':
        raise ValueError(
            f"Unsupported loss_fn {config.loss_fn!r}; "
            "expected 'crossentropy'."
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load dataset
    df_train = pd.read_csv(f"{config.data_root}/train.csv")
    df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
    df_test = pd.read_csv(f"{config.data_root}/test.csv")

    # load model
    # Imported lazily so only the package for the requested variant is needed.
    if config.smi_ted_version == 'v1':
        from smi_ted_light.load import load_smi_ted
    elif config.smi_ted_version == 'v2':
        from smi_ted_large.load import load_smi_ted

    model = load_smi_ted(
        folder=config.model_path,
        ckpt_filename=config.ckpt_filename,
        n_output=config.n_output,
    )
    # NOTE(review): presumably re-initializes the weights of the prediction
    # network `model.net` for the new task -- confirm against `_init_weights`.
    model.net.apply(model._init_weights)
    print(model.net)

    # Effective learning rate is the base rate scaled by the multiplier.
    lr = config.lr_start * config.lr_multiplier
    optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder))
    loss_function = nn.CrossEntropyLoss()

    # init trainer
    trainer = TrainerClassifier(
        raw_data=(df_train, df_valid, df_test),
        dataset_name=config.dataset_name,
        target=config.measure_name,
        batch_size=config.n_batch,
        hparams=config,
        target_metric=config.target_metric,
        seed=config.start_seed,
        checkpoints_folder=config.checkpoints_folder,
        device=device,
        save_ckpt=bool(config.save_ckpt)
    )
    trainer.compile(
        model=model,
        optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)),
        loss_fn=loss_function
    )
    trainer.fit(max_epochs=config.max_epochs)
if __name__ == '__main__':
    # Script entry point: build the project's argument parser, parse the
    # command line, and run the fine-tuning routine with the result.
    cli_config = args.get_parser().parse_args()
    main(cli_config)