# NOTE(review): the lines "Spaces:" / "Sleeping" / "Sleeping" were non-code
# extraction residue (web UI status text), not part of the program.
import argparse
import os
import tarfile
import warnings
from datetime import datetime

import pytorch_lightning as lightning
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from marcai.pl import MARCDataModule, SimilarityVectorModel
from marcai.utils import load_config
def train(name=None):
    """Train the similarity model, export it to ONNX, and archive the run.

    Loads ``config.yaml``, fits a ``SimilarityVectorModel`` on the processed
    MARC data, then writes the best checkpoint, an ONNX export, a copy of the
    config, and a ``.tar.gz`` of the run directory under
    ``saved_models_dir/<name>``.

    Args:
        name: Name for this training run, used as the save-directory name.
            Defaults to a timestamp so ``os.path.join`` below never receives
            ``None`` (the CLI passes ``None`` when ``-n`` is omitted).
    """
    if name is None:
        name = datetime.now().strftime("run_%Y%m%d_%H%M%S")

    config_path = "config.yaml"
    config = load_config(config_path)
    # Reuse the already-loaded config instead of reading the file twice.
    model_config = config["model"]

    # Create data module from processed data.
    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )

    # Create model.
    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )

    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    # Save the best checkpoint by validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
    )
    callbacks = [checkpoint_callback]

    # Optional early stopping; a patience of -1 disables it.
    if model_config["patience"] != -1:
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=0.00,
            patience=model_config["patience"],
            verbose=False,
            mode="max",
        )
        callbacks.append(early_stop_callback)

    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"], callbacks=callbacks, accelerator="cpu"
    )
    trainer.fit(model, data)

    # Export ONNX with a dynamic batch dimension.
    onnx_path = os.path.join(save_dir, "model.onnx")
    input_sample = torch.randn((1, len(model.attrs)))
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # Save the config used for this run alongside the artifacts.
    config_filename = os.path.join(save_dir, "config.yaml")
    with open(config_filename, "w") as f:
        f.write(yaml.dump(config))

    # Compress the run directory. The archive lives *inside* save_dir, so it
    # must be excluded from itself — otherwise tar tries to add the partially
    # written tarball to its own contents.
    tar_name = f"{name}.tar.gz"
    tar_path = os.path.join(save_dir, tar_name)

    def _skip_self(tarinfo):
        # Returning None drops the entry (the in-progress archive itself).
        return None if tarinfo.name.endswith(tar_name) else tarinfo

    with tarfile.open(tar_path, mode="w:gz") as archive:
        archive.add(save_dir, arcname=os.path.basename(save_dir), filter=_skip_self)
def main():
    """CLI entry point: parse the optional run name and start training."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-n", "--run-name", help="Name for training run")
    parsed = arg_parser.parse_args()
    train(parsed.run_name)


if __name__ == "__main__":
    main()