"""Convert ESM model state checkpoints to PyTorch model data.

Loads an existing ESM model-data file and a separate model-state file,
replaces the ``'model'`` entry of the model data with the loaded state, and
writes the result to ``<output_dir>/model_data.pt``.
"""

import argparse
import os
import pathlib

import torch

# Name of the file written inside ``output_dir``.
OUTPUT_FILENAME = 'model_data.pt'


def create_parser():
    """Build the command-line parser for this converter.

    Returns:
        argparse.ArgumentParser: Parser with three positional path arguments:
        ``model_data_location``, ``model_state_location`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Convert model state to model data for ESM."
    )
    parser.add_argument(
        "model_data_location",
        type=pathlib.Path,
        help="Model data filepath",
    )
    parser.add_argument(
        "model_state_location",
        type=pathlib.Path,
        help="Model state filepath",
    )
    parser.add_argument(
        "output_dir",
        type=pathlib.Path,
        help="output directory",
    )
    return parser


def main(args):
    """Merge a model-state checkpoint into ESM model data and save the result.

    Args:
        args: Namespace with ``model_data_location``, ``model_state_location``
            and ``output_dir`` attributes, as produced by ``create_parser``.
            ``str`` or path-like values are accepted for every attribute, so
            the function can also be called programmatically.

    Returns:
        pathlib.Path: Path of the written ``model_data.pt`` file.
    """
    # SECURITY NOTE: torch.load unpickles its input, which can execute
    # arbitrary code. Only convert checkpoints that come from trusted sources.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may
    # reject model-data files that hold arbitrary Python objects. Confirm for
    # the checkpoints this script is used with.
    model_data = torch.load(args.model_data_location, map_location='cpu')
    state = torch.load(args.model_state_location, map_location='cpu')

    # Only the 'model' entry is replaced; every other entry of model_data is
    # written out unchanged.
    model_data['model'] = state

    # Coerce to Path so callers may pass a plain str, not just argparse's Path.
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / OUTPUT_FILENAME
    torch.save(model_data, output_path)
    return output_path


if __name__ == "__main__":
    main(create_parser().parse_args())