# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import torch
import torch.nn as nn

from configuration_indictrans import IndicTransConfig
from modeling_indictrans import IndicTransForConditionalGeneration
def remove_ignore_keys_(state_dict):
    """Strip fairseq bookkeeping entries from ``state_dict`` in place.

    These keys (version markers and positional-embedding float tensors)
    exist only in fairseq checkpoints and have no counterpart in the
    HuggingFace model, so they are dropped before loading. Missing keys
    are ignored silently.
    """
    for key in (
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ):
        state_dict.pop(key, None)
def make_linear_from_emb(emb):
    """Build a bias-free linear projection that shares weights with *emb*.

    Args:
        emb: 2-D weight tensor of shape ``(vocab_size, emb_size)``
            (an embedding matrix taken from a state dict).

    Returns:
        ``nn.Linear`` mapping ``emb_size -> vocab_size`` whose weight
        tensor is tied to ``emb`` (no copy is made).
    """
    vocab_size, emb_size = emb.shape
    # nn.Linear stores its weight as (out_features, in_features), so the
    # layer must be constructed with in=emb_size, out=vocab_size for the
    # module metadata to match the assigned (vocab_size, emb_size) weight.
    # (The original passed the arguments swapped; the weight assignment
    # below masked it functionally but left in_features/out_features wrong.)
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.data
    return lin_layer
def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
    """Convert a fairseq IndicTrans checkpoint into a HuggingFace model.

    Args:
        checkpoint_path: path to a fairseq ``checkpoint_best.pt`` file.

    Returns:
        An ``IndicTransForConditionalGeneration`` with weights loaded
        from the fairseq state dict.

    Raises:
        KeyError: if the checkpoint contains neither an ``args`` entry
            nor a ``cfg.model`` section.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from a trusted source (consider weights_only=True on
    # recent torch versions if the checkpoint format allows it).
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # Older fairseq checkpoints store hyperparameters under "args"; newer
    # ones only under cfg.model. Use .get() so a missing "args" key falls
    # through instead of raising KeyError before the fallback is tried.
    args = ckpt.get("args") or ckpt["cfg"]["model"]
    state_dict = ckpt["model"]
    remove_ignore_keys_(state_dict)
    # Vocab sizes are taken from the actual embedding shapes rather than
    # the recorded args, so they always match the weights being loaded.
    encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
    config = IndicTransConfig(
        encoder_vocab_size=encoder_vocab_size,
        decoder_vocab_size=decoder_vocab_size,
        max_source_positions=args.max_source_positions,
        max_target_positions=args.max_target_positions,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        layernorm_embedding=args.layernorm_embedding,
        encoder_normalize_before=args.encoder_normalize_before,
        decoder_normalize_before=args.decoder_normalize_before,
        encoder_attention_heads=args.encoder_attention_heads,
        decoder_attention_heads=args.decoder_attention_heads,
        encoder_ffn_dim=args.encoder_ffn_embed_dim,
        decoder_ffn_dim=args.decoder_ffn_embed_dim,
        encoder_embed_dim=args.encoder_embed_dim,
        decoder_embed_dim=args.decoder_embed_dim,
        encoder_layerdrop=args.encoder_layerdrop,
        decoder_layerdrop=args.decoder_layerdrop,
        dropout=args.dropout,
        attention_dropout=args.attention_dropout,
        activation_dropout=args.activation_dropout,
        activation_function=args.activation_fn,
        share_decoder_input_output_embed=args.share_decoder_input_output_embed,
        scale_embedding=not args.no_scale_embedding,
    )
    model = IndicTransForConditionalGeneration(config)
    # strict=False: fairseq state dicts carry extra keys (e.g. the output
    # projection handled separately below) that the HF model lacks.
    model.model.load_state_dict(state_dict, strict=False)
    if not args.share_decoder_input_output_embed:
        # Untied output projection: rebuild the lm_head from the fairseq
        # output-projection weights instead of the input embeddings.
        model.lm_head = make_linear_from_emb(
            state_dict["decoder.output_projection.weight"]
        )
    print(model)
    return model
if __name__ == "__main__":
    # CLI entry point: convert the fairseq checkpoint at --fairseq_path
    # and save the resulting HuggingFace model to --pytorch_dump_folder_path.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--fairseq_path",
        type=str,
        default="indic-en/model/checkpoint_best.pt",
        help="path to a model.pt on local filesystem.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        default="indic-en/hf_model",
        help="Path to the output PyTorch model.",
    )
    cli_args = parser.parse_args()
    converted = convert_fairseq_IT2_checkpoint_from_disk(cli_args.fairseq_path)
    converted.save_pretrained(cli_args.pytorch_dump_folder_path)