File size: 4,265 Bytes
8cb4f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import models
import argparse, os
from shutil import copy2 as copy
from modules.config import find_all_config

# Maps a command-line run mode to the mode the model object is constructed in,
# when the two differ: "serve" builds the model in "infer" mode and "debug"
# builds it in "eval" mode. Run modes absent from this mapping are passed to the
# model constructor unchanged (lookup is done with .get(mode, mode)).
OVERRIDE_RUN_MODE = dict(serve="infer", debug="eval")

def check_valid_file(path):
  """Validate that ``path`` names an existing regular file; usable as an argparse ``type=``.

  Args:
    path: the path supplied on the command line (``str`` or path-like).

  Returns:
    ``path`` unchanged when it names an existing regular file.

  Raises:
    argparse.ArgumentTypeError: when ``path`` does not exist or is not a regular
      file (e.g. a directory). When used as an argparse ``type=`` callable, the
      parser reports this as a clean usage error instead of a traceback.
  """
  if(os.path.isfile(path)):
    return path
  # ArgumentTypeError is the exception argparse expects from `type=` callables.
  # argparse.ArgumentError needs an (argument, message) pair, so calling it with
  # only a message failed with TypeError and the message was never shown.
  # "{}" instead of "{:s}" so non-str path-likes (pathlib.Path) format too.
  raise argparse.ArgumentTypeError("This path {} is not a valid file, check again.".format(path))

def create_torchscript_model(model, model_dir, model_name, example_input=None):
  """Trace ``model`` with ``torch.jit.trace`` and save the result to ``model_dir/model_name``.

  Tracing records the operations executed for the example input. As with an
  exported TensorFlow graph, the saved artifact therefore carries no native
  Python code, and data-dependent Python control flow in ``forward`` is fixed
  to the path taken during tracing.

  Args:
    model: the module/callable to trace.
    model_dir: directory the traced model is written into (must already exist).
    model_name: file name of the saved TorchScript archive inside ``model_dir``.
    example_input: example tensor (or tuple of tensors) fed through ``model``
      while tracing. Defaults to ``torch.rand(2, 10)`` junk data.

  Returns:
    The path the traced model was saved to.
  """
  import torch
  if(example_input is None):
    example_input = torch.rand(2, 10)
  # Only the example input is passed: the third positional parameter of
  # torch.jit.trace is the deprecated `optimize` flag, not an expected output.
  traced_model = torch.jit.trace(model, example_input)
  save_location = os.path.join(model_dir, model_name)
  traced_model.save(save_location)
  return save_location

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Main argument parser")
  parser.add_argument("run_mode", choices=("train", "eval", "infer", "debug", "serve"), help="Main running mode of the program")
  # Required: every run mode constructs models.AvailableModels[args.model]; a
  # missing value previously surfaced only as a KeyError on None.
  parser.add_argument("--model", type=str, required=True, choices=models.AvailableModels.keys(), help="The type of model to be ran")
  parser.add_argument("--model_dir", type=str, required=True, help="Location of model")
  parser.add_argument("--config", type=str, nargs="+", default=None, help="Location of the config file")
  # store_false: args.no_keeping_config defaults to True (configs ARE copied)
  # and becomes False when the flag is given.
  parser.add_argument("--no_keeping_config", action="store_false", help="If set, do not copy the config file to the model directory")
  # arguments for inference
  parser.add_argument("--features_file", type=str, help="Inference mode: Provide the location of features file")
  parser.add_argument("--predictions_file", type=str, help="Inference mode: Provide Location of output file which is predicted from features file")
  parser.add_argument("--src_lang", type=str, help="Inference mode: Provide language used by source file")
  parser.add_argument("--trg_lang", type=str, default=None, help="Inference mode: Choose language that is translated from source file. NOTE: only specify for multilingual model")
  parser.add_argument("--infer_batch_size", type=int, default=None, help="Specify the batch_size to run the model with. Default use the config value.")
  parser.add_argument("--checkpoint", type=str, default=None, help="All mode: specify to load the checkpoint into model.")
  parser.add_argument("--checkpoint_idx", type=int, default=0, help="All mode: specify the epoch of the checkpoint loaded. Only useful for training.")
  parser.add_argument("--serve_path", type=str, default=None, help="File to save TorchScript model into.")

  args = parser.parse_args()

  # Reject unsupported or incomplete requests before creating directories,
  # copying configs, constructing the model or loading checkpoints.
  if(args.run_mode == "debug"):
    # model.run_debug(model_dir=..., config=...) is not wired up yet.
    raise NotImplementedError("Run mode debug is not implemented yet.")
  if(args.run_mode == "serve" and args.serve_path is None):
    # ArgumentParser instances have no ArgumentError attribute; parser.error()
    # prints the usage message and exits with status 2.
    parser.error("In serving, --serve_path cannot be empty")

  # create directory if not exist
  os.makedirs(args.model_dir, exist_ok=True)
  config_path = args.config
  if(config_path is None):
    config_path = find_all_config(args.model_dir)
    print("Config path not specified, load the configs in model directory which is {}".format(config_path))
  elif(args.no_keeping_config):
    # store false variable, mean true is default
    print("Config specified, copying all to model dir")
    for subpath in config_path:
      target = os.path.join(args.model_dir, os.path.basename(subpath))
      # copy2 raises shutil.SameFileError when the config already lives in
      # model_dir (e.g. `--config <model_dir>/config.yml`); nothing to copy then.
      if(os.path.exists(target) and os.path.samefile(subpath, target)):
        continue
      copy(subpath, args.model_dir)

  # Some run modes construct the model under a different mode name
  # (see OVERRIDE_RUN_MODE); others are passed through unchanged.
  model_mode = OVERRIDE_RUN_MODE.get(args.run_mode, args.run_mode)
  model = models.AvailableModels[args.model](config=config_path, model_dir=args.model_dir, mode=model_mode)
  model.load_checkpoint(args.model_dir, checkpoint=args.checkpoint, checkpoint_idx=args.checkpoint_idx)
  # run model, dispatching on the run mode exactly as given on the command line
  # ("debug" was already rejected above).
  run_mode = args.run_mode
  if(run_mode == "train"):
    model.run_train(model_dir=args.model_dir, config=config_path)
  elif(run_mode == "eval"):
    model.run_eval(model_dir=args.model_dir, config=config_path)
  elif(run_mode == "infer"):
    model.run_infer(args.features_file, args.predictions_file, src_lang=args.src_lang, trg_lang=args.trg_lang, config=config_path, batch_size=args.infer_batch_size)
  elif(run_mode == "serve"):
    model.prepare_serve(args.serve_path, model_dir=args.model_dir, config=config_path)
  else:
    raise ValueError("Run mode {:s} not implemented.".format(run_mode))