File size: 1,465 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
#!/usr/bin/env python
import argparse
import torch
def main(argv=None):
    """Release an OpenNMT-py model for inference.

    With ``--format pytorch`` (the default) the checkpoint at ``--model`` is
    loaded on CPU, its ``"optim"`` entry is set to ``None`` and the result is
    saved to ``--output``. With ``--format ctranslate2`` the checkpoint path
    is handed to ``ctranslate2.converters.OpenNMTPyConverter`` and converted
    into ``--output``, optionally quantized.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of the original zero-argument call ``main()``.

    Raises:
        RuntimeError: If the installed ``ctranslate2`` package exposes no
            ``__version__`` attribute (the script requires >= 2.0.0).
        SystemExit: Raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference")
    parser.add_argument("--model", "-m",
                        help="The model path", required=True)
    parser.add_argument("--output", "-o",
                        help="The output path", required=True)
    parser.add_argument("--format",
                        choices=["pytorch", "ctranslate2"],
                        default="pytorch",
                        help="The format of the released model")
    parser.add_argument("--quantization", "-q",
                        choices=["int8", "int16", "float16", "int8_float16"],
                        default=None,
                        help="Quantization type for CT2 model.")
    opt = parser.parse_args(argv)

    if opt.format == "pytorch":
        if opt.quantization is not None:
            # The option only affects the CTranslate2 conversion below; tell
            # the user instead of silently dropping it.
            import warnings
            warnings.warn(
                "--quantization is only used with --format ctranslate2; "
                "it is ignored for pytorch output")
        # Load here rather than before the format check: the ctranslate2
        # branch passes the checkpoint *path* to the converter and never uses
        # the loaded object.
        # NOTE(review): torch.load unpickles the file; only run this on
        # checkpoints from a trusted source.
        model = torch.load(opt.model, map_location=torch.device("cpu"))
        # Clear the optimizer entry before saving the released checkpoint.
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        # Imported lazily so the pytorch path works without ctranslate2.
        import ctranslate2
        # A missing __version__ attribute is treated as a pre-2.0.0 install.
        if not hasattr(ctranslate2, "__version__"):
            raise RuntimeError(
                "onmt_release_model script requires ctranslate2 >= 2.0.0"
            )
        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        # force=True is forwarded to the converter; presumably it allows
        # overwriting an existing output directory (confirm in CT2 docs).
        converter.convert(opt.output, force=True,
                          quantization=opt.quantization)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|