# yuyan-10b / megatron / model / realm_model.py
# Shawn001's picture
# Upload 131 files
# 23bd7af
import os
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Construct and return an ``ICTBertModel`` configured from the global args.

    Args:
        only_query_model: forwarded to ``ICTBertModel``; build only the query tower.
        only_block_model: forwarded to ``ICTBertModel``; build only the block tower.

    Returns:
        The newly built ``ICTBertModel``.

    Raises:
        AssertionError: if ``--ict-head-size`` was not given, or if the tensor or
            pipeline model-parallel world size is not 1.
    """
    head_size = get_args().ict_head_size
    assert head_size is not None, (
        "Need to specify --ict-head-size to provide an ICTBertModel")
    assert (mpu.get_tensor_model_parallel_world_size() == 1
            and mpu.get_pipeline_model_parallel_world_size() == 1), (
        "Model parallel size > 1 not supported for ICT")

    print_rank_0('building ICTBertModel...')

    # Two token types are kept because the LM used for initialization has two.
    return ICTBertModel(
        ict_head_size=head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model,
    )
class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task.

    Holds up to two ``IREncoderBertModel`` encoders: ``query_model`` (checkpoint
    key ``'question_model'``) and ``block_model`` (checkpoint key
    ``'context_model'``). Either one can be omitted via ``only_block_model`` /
    ``only_query_model``; the ``use_query_model`` / ``use_block_model`` flags
    record which encoders exist.
    """

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        """Build the query and/or block encoders.

        Args:
            ict_head_size: output size of each encoder's ICT head.
            num_tokentypes: number of token types passed to each encoder.
            parallel_output: forwarded unchanged to each encoder.
            only_query_model: if True, do not build the block encoder.
            only_block_model: if True, do not build the query encoder.

        Raises:
            AssertionError: if both ``only_query_model`` and
                ``only_block_model`` are set.
        """
        super(ICTBertModel, self).__init__()
        bert_kwargs = dict(
            ict_head_size=ict_head_size,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output
        )
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = IREncoderBertModel(**bert_kwargs)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = IREncoderBertModel(**bert_kwargs)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
        """Run a forward pass for each of the models and return the respective embeddings.

        Returns:
            Tuple ``(query_logits, block_logits)``.

        Raises:
            ValueError: if either encoder was not built.
        """
        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model.

        Raises:
            ValueError: if the query model was not built.
        """
        if not self.use_query_model:
            raise ValueError("Cannot embed query without query model.")

        # All-zero token types, created on the same device as the tokens
        # (rather than unconditionally on the current CUDA device).
        query_types = torch.zeros_like(query_tokens, dtype=torch.long)
        # Call the module itself (not .forward) so registered hooks still run.
        query_ict_logits, _ = self.query_model(query_tokens, query_attention_mask, query_types)
        return query_ict_logits

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model.

        Raises:
            ValueError: if the block model was not built.
        """
        if not self.use_block_model:
            raise ValueError("Cannot embed block without block model.")

        # All-zero token types, created on the same device as the tokens
        # (rather than unconditionally on the current CUDA device).
        block_types = torch.zeros_like(block_tokens, dtype=torch.long)
        # Call the module itself (not .forward) so registered hooks still run.
        block_ict_logits, _ = self.block_model(block_tokens, block_attention_mask, block_types)
        return block_ict_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models.

        Returns:
            A dict with one entry per built encoder, keyed by
            ``'question_model'`` and/or ``'context_model'``.
        """
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models.

        ``state_dict`` must contain an entry under the checkpoint key of every
        encoder that was built.
        """
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining.

        Reads the iteration number from the tracker file under
        ``args.bert_load``, loads that checkpoint onto the CPU, and copies its
        ``['model']['language_model']`` entry into the language model of every
        encoder that was built. When both encoders exist, the block encoder's
        ICT head is also set to the query encoder's ICT head so both start
        from the same head weights.

        Raises:
            FileNotFoundError: if the tracker file does not exist.
            AssertionError: if the tracked iteration is not positive.
            ValueError: if the checkpoint file cannot be loaded.
        """
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
        assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # NOTE(review): torch.load unpickles the file; only point --bert-load
        # at checkpoints from a trusted source.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except Exception as err:
            # Catch Exception (not BaseException) so KeyboardInterrupt and
            # SystemExit propagate, and keep the original cause attached.
            raise ValueError("Could not load checkpoint") from err

        # load the LM state dict into each model that was actually built
        model_dict = state_dict['model']['language_model']
        if self.use_query_model:
            self.query_model.language_model.load_state_dict(model_dict)
        if self.use_block_model:
            self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        if self.use_query_model and self.use_block_model:
            self.block_model.ict_head.load_state_dict(
                self.query_model.ict_head.state_dict())
class IREncoderBertModel(MegatronModule):
    """BERT-based encoder for queries or blocks used for learned information retrieval.

    Consists of a language model built by ``get_language_model`` (with a
    pooler) followed by ``ict_head``, a linear layer from ``args.hidden_size``
    to ``ict_head_size`` applied to the pooled output.
    """

    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        """Build the language model and the ICT head.

        Args:
            ict_head_size: output size of the ICT head.
            num_tokentypes: number of token types passed to the language model.
            parallel_output: stored on the instance; not otherwise used here.
        """
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Encode ``input_ids`` and project the pooled output through the ICT head.

        Returns:
            Tuple ``(ict_logits, None)``; the second element is always ``None``.
        """
        # The extended mask is built with the dtype of the language model's
        # first parameter.
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        # Only the pooled output is needed; the per-token output is discarded.
        _, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns:
            A dict with the language model state under the key returned by
            ``get_language_model`` and the ICT head state under ``'ict_head'``.
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        # Keyword arguments: passing these positionally to
        # nn.Module.state_dict is deprecated in recent PyTorch releases.
        # NOTE(review): assumes ict_head is a torch.nn.Module - confirm
        # against get_linear_layer.
        state_dict_[self._ict_head_key] \
            = self.ict_head.state_dict(destination=destination,
                                       prefix=prefix,
                                       keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Loads the language model and the ICT head from their respective
        entries in ``state_dict``.
        """
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.ict_head.load_state_dict(
            state_dict[self._ict_head_key], strict=strict)