# NMT-LaVi/bin/serve.py
import io
import os

import torch

import utils.save as saver
# import models  # only needed to pick a class from models.AvailableModels (see below)
from models.transformer import Transformer
from modules.config import find_all_config
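# TorchServe custom handler: the serving frontend imports this module and calls the
# module-level handle(data, context) at the bottom for every request batch. The
# context exposes `manifest` (files packed by the model archiver) and
# `system_properties` (model_dir, gpu_id), which is all this module relies on.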
class TransformerHandlerClass:
    """Handler that loads a Transformer checkpoint and translates batches of sentences."""

    def __init__(self):
        self.model = None
        self.device = None
        self.initialized = False
    def _find_checkpoint(self, model_dir, best_model_prefix="best_model", model_prefix="model", validate=True):
        """Attempt to retrieve the best model checkpoint from model_dir; failing that, the checkpoint of the latest iteration.

        Args:
            model_dir: location to search for checkpoints (str).

        Returns:
            A single str denoting the checkpoint path.
        """
        score_file_path = os.path.join(model_dir, saver.BEST_MODEL_FILE)
        if(os.path.isfile(score_file_path)):  # score file exists -> use the best model
            best_model_path = os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(best_model_prefix, 0, saver.MODEL_EXTENSION))
            if(validate):
                assert os.path.isfile(best_model_path), "Score file is available, but file {:s} is missing.".format(best_model_path)
            return best_model_path
        else:  # no score file -> fall back to the latest model
            last_checkpoint_idx = saver.check_model_in_path(model_dir, name_prefix=model_prefix)
            if(last_checkpoint_idx == 0):
                raise ValueError("No checkpoint found in folder {:s} with prefix {:s}.".format(model_dir, model_prefix))
            else:
                return os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(model_prefix, last_checkpoint_idx, saver.MODEL_EXTENSION))
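    # Illustration of the naming scheme (hypothetical values; the real constants
    # live in utils.save): with MODEL_FILE_FORMAT = "{:s}.iter{:d}.{:s}" and
    # MODEL_EXTENSION = "pt", this method would return e.g. "best_model.iter0.pt"
    # or "model.iter12.pt".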
    def initialize(self, ctx):
        """Load config, vocab and checkpoint from the TorchServe context."""
        manifest = ctx.manifest
        properties = ctx.system_properties
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        self.model_dir = model_dir = properties.get("model_dir")
        # extract checkpoint location, config & model name
        model_serve_file = os.path.join(model_dir, saver.MODEL_SERVE_FILE)
        with io.open(model_serve_file, "r") as serve_config:
            model_name = serve_config.read().strip()
        # model_cls = models.AvailableModels[model_name]
        model_cls = Transformer  # can't select by name due to the nature of the model file
        # attempt to use the checkpoint fed from the archiver; else fall back to the best checkpoint found
        checkpoint_path = manifest['model'].get('serializedFile', self._find_checkpoint(model_dir))
        config_path = find_all_config(model_dir)
        # load model with inbuilt config + vocab, without pretraining data
        self.model = model = model_cls(config=config_path, model_dir=model_dir, mode="infer")
        # TODO: _find_checkpoint may be redundant here, since load_checkpoint already searches for the latest checkpoint
        model.load_checkpoint(model_dir, checkpoint=checkpoint_path)
        print("Model {:s} loaded successfully at location {:s}.".format(model_name, model_dir))
        self.initialized = True
    def handle(self, data):
        """Process a batch of data received from the client.

        Args:
            data: the object received from the client; a list of dicts whose "data" field holds [batch_size] of str.

        Returns:
            The expected translations, [batch_size] of str.
        """
        batch_sentences = data[0].get("data")
        # assert batch_sentences is not None, "data is {}".format(data)
        # make sure that sentences are detokenized before returning
        translated_sentences = self.model.translate_batch(batch_sentences, output_tokens=False)
        return translated_sentences
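    # Example payload as seen by handle() (hypothetical content): TorchServe wraps
    # each request body into a list of dicts, so a client posting
    # {"data": ["a b c ."]} arrives here as [{"data": ["a b c ."]}], and the
    # return value is the corresponding list of translated strings.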
class BeamSearchHandlerClass:
    """Handler variant that serves a TorchScript-exported model through beam-search decoding."""

    def __init__(self):
        self.model = None
        self.inferrer = None
        self.initialized = False

    def initialize(self, ctx):
        manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties['model_dir']
        ts_modelpath = manifest['model']['serializedFile']
        self.model = ts_model = torch.jit.load(os.path.join(model_dir, ts_modelpath))
        from modules.inference.beam_search import BeamSearch  # deferred import; only this handler needs it
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.inferrer = BeamSearch(ts_model, 160, device, beam_size=5)  # 160 is presumably the maximum decode length
        self.initialized = True
    def handle(self, data):
        batch_sentences = data[0].get("data")
        # assert batch_sentences is not None, "data is {}".format(data)
        translated_sentences = self.inferrer.translate_batch_sentence(batch_sentences, output_tokens=False)
        return translated_sentences
# swap in TransformerHandlerClass() here to serve the eager-mode model instead
RUNNING_MODEL = BeamSearchHandlerClass()

def handle(data, context):
    """TorchServe entry point: lazily initialize the chosen handler, then delegate to it."""
    if(not RUNNING_MODEL.initialized):  # lazy init on the first request
        RUNNING_MODEL.initialize(context)
    if(data is None):
        return None
    return RUNNING_MODEL.handle(data)
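# Minimal local smoke test (hypothetical paths and payload; in production TorchServe
# constructs the context object). It mocks only the two attributes this module
# reads: ctx.manifest and ctx.system_properties.
if __name__ == "__main__":
    class _MockContext:
        manifest = {"model": {"serializedFile": "model_ts.pt"}}
        system_properties = {"model_dir": ".", "gpu_id": 0}

    sample_batch = [{"data": ["a b c ."]}]
    print(handle(sample_batch, _MockContext()))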