# Spaces:
# No application file
# No application file
import argparse
import os
from shutil import SameFileError
from shutil import copy2 as copy

import models
from modules.config import find_all_config
# CLI run modes whose model must be constructed in a different mode:
# "serve" builds the model in "infer" mode and "debug" builds it in "eval" mode.
# Run modes absent from this mapping are passed to the model unchanged.
OVERRIDE_RUN_MODE = {"serve": "infer", "debug": "eval"}
def check_valid_file(path):
    """Return *path* unchanged if it points to an existing regular file.

    Suitable as an argparse ``type=`` callable.

    Args:
        path: File system path (string or path-like) to validate.

    Returns:
        The same ``path`` object that was passed in.

    Raises:
        argparse.ArgumentTypeError: If ``path`` is not an existing file.
    """
    if os.path.isfile(path):
        return path
    # ArgumentError needs an (argument, message) pair and cannot be raised with
    # a message alone; ArgumentTypeError is the exception argparse expects from
    # a validation callable.
    raise argparse.ArgumentTypeError(
        "This path {} is not a valid file, check again.".format(path)
    )
def create_torchscript_model(model, model_dir, model_name, example_inputs=None):
    """Trace *model* into a TorchScript module and save it to disk.

    NOTE: as with an exported TensorFlow graph, the traced result is a limited
    model that carries no native Python script.

    Args:
        model: Callable / ``torch.nn.Module`` to trace.
        model_dir: Directory the traced model is written into.
        model_name: File name of the saved model inside ``model_dir``.
        example_inputs: Optional tensor (or tuple of tensors) used to drive the
            trace. Defaults to a random ``(2, 10)`` float tensor.
    """
    import torch

    if example_inputs is None:
        # Junk data: only used to run the model once so the trace can be recorded.
        example_inputs = torch.rand(2, 10)
    # torch.jit.trace takes no "expected output" argument; a third positional
    # argument would land on the deprecated `optimize` parameter, so only the
    # example inputs are passed.
    traced_model = torch.jit.trace(model, example_inputs)
    save_location = os.path.join(model_dir, model_name)
    traced_model.save(save_location)
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, resolve the config files,
    # build the requested model, load its checkpoint and dispatch on run mode.
    parser = argparse.ArgumentParser(description="Main argument parser")
    parser.add_argument("run_mode", choices=("train", "eval", "infer", "debug", "serve"), help="Main running mode of the program")
    parser.add_argument("--model", type=str, choices=models.AvailableModels.keys(), help="The type of model to be ran")
    parser.add_argument("--model_dir", type=str, required=True, help="Location of model")
    parser.add_argument("--config", type=str, nargs="+", default=None, help="Location of the config file")
    parser.add_argument("--no_keeping_config", action="store_false", help="If set, do not copy the config file to the model directory")
    # arguments for inference
    parser.add_argument("--features_file", type=str, help="Inference mode: Provide the location of features file")
    parser.add_argument("--predictions_file", type=str, help="Inference mode: Provide Location of output file which is predicted from features file")
    parser.add_argument("--src_lang", type=str, help="Inference mode: Provide language used by source file")
    parser.add_argument("--trg_lang", type=str, default=None, help="Inference mode: Choose language that is translated from source file. NOTE: only specify for multilingual model")
    parser.add_argument("--infer_batch_size", type=int, default=None, help="Specify the batch_size to run the model with. Default use the config value.")
    parser.add_argument("--checkpoint", type=str, default=None, help="All mode: specify to load the checkpoint into model.")
    parser.add_argument("--checkpoint_idx", type=int, default=0, help="All mode: specify the epoch of the checkpoint loaded. Only useful for training.")
    parser.add_argument("--serve_path", type=str, default=None, help="File to save TorchScript model into.")
    args = parser.parse_args()

    # Validate argument combinations up front so usage errors are reported
    # before any directory is created or model is loaded.
    if args.model not in models.AvailableModels:
        # --model is optional for argparse; without this check a missing value
        # would only surface later as a bare KeyError.
        parser.error("--model must be one of the available models")
    if args.run_mode == "serve" and args.serve_path is None:
        # ArgumentParser has no `ArgumentError` attribute; parser.error prints
        # the usage message and exits with status 2.
        parser.error("In serving, --serve_path cannot be empty")

    # create directory if not exist
    os.makedirs(args.model_dir, exist_ok=True)

    config_path = args.config
    if config_path is None:
        config_path = find_all_config(args.model_dir)
        print("Config path not specified, load the configs in model directory which is {}".format(config_path))
    elif args.no_keeping_config:
        # store_false flag: the attribute is True unless --no_keeping_config
        # is given, so True means "keep a copy of the configs".
        print("Config specified, copying all to model dir")
        for subpath in config_path:
            try:
                copy(subpath, args.model_dir)
            except SameFileError:
                # The config already lives in the model directory; nothing to copy.
                pass

    # load model. Specific run mode required converting
    model_mode = OVERRIDE_RUN_MODE.get(args.run_mode, args.run_mode)
    model = models.AvailableModels[args.model](config=config_path, model_dir=args.model_dir, mode=model_mode)
    model.load_checkpoint(args.model_dir, checkpoint=args.checkpoint, checkpoint_idx=args.checkpoint_idx)

    # run model; dispatch uses the run mode as requested on the command line,
    # not the overridden mode the model was built with
    run_mode = args.run_mode
    if run_mode == "train":
        model.run_train(model_dir=args.model_dir, config=config_path)
    elif run_mode == "eval":
        model.run_eval(model_dir=args.model_dir, config=config_path)
    elif run_mode == "infer":
        model.run_infer(args.features_file, args.predictions_file, src_lang=args.src_lang, trg_lang=args.trg_lang, config=config_path, batch_size=args.infer_batch_size)
    elif run_mode == "debug":
        # Debug mode is deliberately disabled; the call to model.run_debug that
        # used to follow this raise could never execute.
        raise NotImplementedError
    elif run_mode == "serve":
        model.prepare_serve(args.serve_path, model_dir=args.model_dir, config=config_path)
    else:
        # Defensive: argparse `choices` should already reject anything else.
        raise ValueError("Run mode {:s} not implemented.".format(run_mode))