sakharamg's picture
Uploading all files
158b61b
#!/usr/bin/env python
import argparse
import torch
def main(argv=None):
    """Release an OpenNMT-py checkpoint for inference.

    Parses the command line, then either re-saves the checkpoint with its
    ``"optim"`` entry set to ``None`` (``--format pytorch``, the default) or
    hands the checkpoint path to the CTranslate2 OpenNMT-py converter
    (``--format ctranslate2``).

    Args:
        argv: Optional list of command-line argument strings. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, so the
            zero-argument call used by script entry points is unchanged.

    Raises:
        RuntimeError: If the imported ``ctranslate2`` module has no
            ``__version__`` attribute.
    """
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference")
    parser.add_argument("--model", "-m",
                        help="The model path", required=True)
    parser.add_argument("--output", "-o",
                        help="The output path", required=True)
    parser.add_argument("--format",
                        choices=["pytorch", "ctranslate2"],
                        default="pytorch",
                        help="The format of the released model")
    parser.add_argument("--quantization", "-q",
                        choices=["int8", "int16", "float16", "int8_float16"],
                        default=None,
                        help="Quantization type for CT2 model.")
    opt = parser.parse_args(argv)

    if opt.format == "pytorch":
        # Load only in this branch: the CTranslate2 converter below is given
        # the checkpoint path and never uses an in-memory copy, so loading
        # up front would be wasted work for that format.
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the installed torch version/defaults -- only run this on
        # checkpoints from a trusted source.
        model = torch.load(opt.model, map_location=torch.device("cpu"))
        # Drop the optimizer state; it is not needed for inference.
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        # Imported lazily so the pytorch path works without ctranslate2.
        import ctranslate2
        # The presence of __version__ is used as the marker for a
        # sufficiently recent ctranslate2 (see the error message below).
        if not hasattr(ctranslate2, "__version__"):
            raise RuntimeError(
                "onmt_release_model script requires ctranslate2 >= 2.0.0"
            )
        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        # force=True is passed through to the converter; presumably it allows
        # replacing an existing output directory -- confirm in the CT2 docs.
        converter.convert(opt.output, force=True,
                          quantization=opt.quantization)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()