from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.utils.checkpoint
import math
import random


class RetrieverConfig(PretrainedConfig):
    model_type = "retriever"

    def __init__(
        self,
        encoder_model_name="microsoft/deberta-v3-large",
        max_seq_len=512,
        mean_passage_len=70,
        beam_size=1,
        gradient_checkpointing=False,
        use_label_order=False,
        use_negative_sampling=False,
        use_focal=False,
        use_early_stop=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder_model_name = encoder_model_name
        self.max_seq_len = max_seq_len
        self.mean_passage_len = mean_passage_len
        self.beam_size = beam_size
        self.gradient_checkpointing = gradient_checkpointing
        self.use_label_order = use_label_order
        self.use_negative_sampling = use_negative_sampling
        self.use_focal = use_focal
        self.use_early_stop = use_early_stop


class Retriever(PreTrainedModel):
    config_class = RetrieverConfig

    def __init__(self, config):
        super().__init__(config)
        encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
        self.encoder = AutoModel.from_pretrained(
            config.encoder_model_name, config=encoder_config
        )
        self.hop_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
        self.hop_n_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
        # Mirror the config options that forward() and
        # get_negative_sampling_results() read as plain attributes.
        self.max_seq_len = config.max_seq_len
        self.mean_passage_len = config.mean_passage_len
        self.beam_size = config.beam_size
        self.gradient_checkpointing = config.gradient_checkpointing
        self.use_label_order = config.use_label_order
        self.use_negative_sampling = config.use_negative_sampling
        self.use_focal = config.use_focal
        self.use_early_stop = config.use_early_stop
        if config.gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()
        # Initialize weights and apply final processing
        self.post_init()

    def get_negative_sampling_results(self, context_ids, current_preds, sf_idx):
        # Sample a shrinking number of negative passages per beam slot:
        # roughly half of the passages for the first slot, a quarter for the
        # second, and so on.
        closest_power_of_2 = 2 ** math.floor(math.log2(self.beam_size))
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(0.5, powers)
        each_sampling_nums = [max(1, int(len(context_ids) * item)) for item in slopes]
        last_pred_idx = set()
        sampled_set = {}
        for i in range(self.beam_size):
            last_pred_idx.add(current_preds[i][-1])
            sampled_set[i] = []
            for j in range(len(context_ids)):
                if j in current_preds[i] or j in last_pred_idx:
                    continue
                if set(current_preds[i] + [j]) == set(sf_idx):
                    continue
                sampled_set[i].append(j)
            random.shuffle(sampled_set[i])
            sampled_set[i] = sampled_set[i][: each_sampling_nums[i]]
        return sampled_set

    def forward(self, q_codes, c_codes, sf_idx, hop=0):
        """
        Beam search over passages, one encoder call per hop.

        :param q_codes: list containing one 1-D tensor of question token ids.
        :param c_codes: list containing one list of 1-D tensors of passage token ids.
        :param sf_idx: list containing one list of supporting-fact passage indices
            (labels during training; also used to infer the hop count).
        :param hop: predefined number of hops at inference time; if 0, the hop
            count is taken from ``sf_idx``.

        ``config.cls_token_id`` / ``config.sep_token_id`` are inserted by the
        model itself, so the token id tensors are expected without special tokens.
        """
        device = q_codes[0].device
        total_loss = torch.tensor(0.0, device=device, requires_grad=True)
        # token ids of the question plus the passages selected so far, carried
        # over from the previous hop
        last_prediction = None
        pre_question_ids = None
        loss_function = nn.CrossEntropyLoss()
        focal_loss_function = None
        if self.use_focal:
            # NOTE: FocalLoss is not defined in this file; it has to be
            # imported or defined elsewhere before use_focal=True can be used.
            focal_loss_function = FocalLoss()
        question_ids = q_codes[0]
        context_ids = c_codes[0]
        current_preds = []
        if self.training:
            sf_idx = sf_idx[0]
            sf = sf_idx
            hops = len(sf)
        else:
            hops = hop if hop > 0 else len(sf_idx[0])
        if len(context_ids) <= hops or hops < 1:
            return {"current_preds": [list(range(hops))], "loss": total_loss}
        mean_passage_len = (self.max_seq_len - 2 - question_ids.shape[-1]) // hops
        for idx in range(hops):
            if idx == 0:
                # first hop
                qp_len = [
                    min(
                        self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                        question_ids.shape[-1] + c.shape[-1],
                    )
                    for c in context_ids
                ]
                next_question_ids = []
                hop1_qp_ids = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                hop1_qp_attention_mask = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop1_label = torch.zeros(
                        [len(context_ids)], dtype=torch.long, device=device
                    )
                for i in range(len(context_ids)):
                    this_question_ids = torch.cat((question_ids, context_ids[i]))[
                        : qp_len[i]
                    ]
                    hop1_qp_ids[i, 1 : qp_len[i] + 1] = this_question_ids.view(-1)
                    hop1_qp_ids[i, 0] = self.config.cls_token_id
                    hop1_qp_ids[i, qp_len[i] + 1] = self.config.sep_token_id
                    hop1_qp_attention_mask[i, : qp_len[i] + 1] = 1
                    if self.training:
                        if self.use_label_order:
                            if i == sf_idx[0]:
                                hop1_label[i] = 1
                        else:
                            if i in sf_idx:
                                hop1_label[i] = 1
                    next_question_ids.append(this_question_ids)
                hop1_encoder_outputs = self.encoder(
                    input_ids=hop1_qp_ids, attention_mask=hop1_qp_attention_mask
                )[0][
                    :, 0, :
                ]  # [doc_num, hidden_size]
                if self.training and self.gradient_checkpointing:
                    hop1_projection = torch.utils.checkpoint.checkpoint(
                        self.hop_classifier_layer, hop1_encoder_outputs
                    )  # [doc_num, 2]
                else:
                    hop1_projection = self.hop_classifier_layer(
                        hop1_encoder_outputs
                    )  # [doc_num, 2]
                if self.training:
                    total_loss = total_loss + loss_function(hop1_projection, hop1_label)
                _, hop1_pred_documents = hop1_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = hop1_pred_documents  # used for taking new_question_ids
                pre_question_ids = next_question_ids
                current_preds = [
                    [item.item()] for item in hop1_pred_documents
                ]  # used for taking the original passage index of the current passage
            else:
                # set up the vectors outside the beam_size loop
                qp_len_total = {}
                max_qp_len = 0
                last_pred_idx = set()
                if self.training:
                    # stop predicting if the current hop's predictions are wrong
                    flag = False
                    for i in range(self.beam_size):
                        if self.use_label_order:
                            if current_preds[i][-1] == sf_idx[idx - 1]:
                                flag = True
                                break
                        else:
                            if set(current_preds[i]) == set(sf_idx[:idx]):
                                flag = True
                                break
                    if not flag and self.use_early_stop:
                        break
                for i in range(self.beam_size):
                    # expand the search space; self.beam_size is the number of predicted passages
                    pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage; for example, it should be 9+8 instead of 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    qp_len = {}
                    # obtain the sequence length which can be formed into the vector
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        qp_len[j] = min(
                            self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                            new_question_ids.shape[-1] + context_ids[j].shape[-1],
                        )
                        max_qp_len = max(max_qp_len, qp_len[j])
                    qp_len_total[i] = qp_len
                if len(qp_len_total) < 1:
                    # skip if all the predictions in the last hop are wrong
                    break
                if self.use_negative_sampling and self.training:
                    # deprecated
                    current_sf = [sf_idx[idx]] if self.use_label_order else sf_idx
                    sampled_set = self.get_negative_sampling_results(
                        context_ids, current_preds, sf_idx[: idx + 1]
                    )
                    vector_num = 1
                    for k in range(self.beam_size):
                        vector_num += len(sampled_set[k])
                else:
                    vector_num = sum([len(v) for k, v in qp_len_total.items()])
                # set up the vectors
                hop_qp_ids = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                hop_qp_attention_mask = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop_label = torch.zeros(
                        [vector_num], dtype=torch.long, device=device
                    )
                vec_idx = 0
                pred_mapping = []
                next_question_ids = []
                last_pred_idx = set()
                for i in range(self.beam_size):
                    # expand the search space; self.beam_size is the number of predicted passages
                    pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage; for example, it should be 9+8 instead of 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        if self.training and self.use_negative_sampling:
                            if j not in sampled_set[i] and not (
                                set(current_preds[i] + [j]) == set(sf_idx[: idx + 1])
                            ):
                                continue
                        # shuffle the order between documents
                        pre_context_ids = (
                            new_question_ids[question_ids.shape[-1] :].clone().detach()
                        )
                        context_list = [pre_context_ids, context_ids[j]]
                        if self.training:
                            random.shuffle(context_list)
                        this_question_ids = torch.cat(
                            (
                                question_ids,
                                torch.cat((context_list[0], context_list[1])),
                            )
                        )[: qp_len_total[i][j]]
                        next_question_ids.append(this_question_ids)
                        hop_qp_ids[
                            vec_idx, 1 : qp_len_total[i][j] + 1
                        ] = this_question_ids
                        hop_qp_ids[vec_idx, 0] = self.config.cls_token_id
                        hop_qp_ids[
                            vec_idx, qp_len_total[i][j] + 1
                        ] = self.config.sep_token_id
                        hop_qp_attention_mask[vec_idx, : qp_len_total[i][j] + 1] = 1
                        if self.training:
                            if self.use_negative_sampling:
                                if set(current_preds[i] + [j]) == set(
                                    sf_idx[: idx + 1]
                                ):
                                    hop_label[vec_idx] = 1
                            else:
                                # if self.use_label_order:
                                if set(current_preds[i] + [j]) == set(
                                    sf_idx[: idx + 1]
                                ):
                                    hop_label[vec_idx] = 1
                                # else:
                                #     if j in sf_idx:
                                #         hop_label[vec_idx] = 1
                        pred_mapping.append(current_preds[i] + [j])
                        vec_idx += 1
                assert len(pred_mapping) == hop_qp_ids.shape[0]
                hop_encoder_outputs = self.encoder(
                    input_ids=hop_qp_ids, attention_mask=hop_qp_attention_mask
                )[0][
                    :, 0, :
                ]  # [vec_num, hidden_size]
                # if idx == 1:
                #     hop_projection_func = self.hop2_classifier_layer
                # elif idx == 2:
                #     hop_projection_func = self.hop3_classifier_layer
                # else:
                #     hop_projection_func = self.hop4_classifier_layer
                hop_projection_func = self.hop_n_classifier_layer
                if self.training and self.gradient_checkpointing:
                    hop_projection = torch.utils.checkpoint.checkpoint(
                        hop_projection_func, hop_encoder_outputs
                    )  # [vec_num, 2]
                else:
                    hop_projection = hop_projection_func(
                        hop_encoder_outputs
                    )  # [vec_num, 2]
                if self.training:
                    if not self.use_focal:
                        total_loss = total_loss + loss_function(
                            hop_projection, hop_label
                        )
                    else:
                        total_loss = total_loss + focal_loss_function(
                            hop_projection, hop_label
                        )
                _, hop_pred_documents = hop_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = hop_pred_documents
                pre_question_ids = next_question_ids
                current_preds = [
                    pred_mapping[hop_pred_documents[i].item()]
                    for i in range(self.beam_size)
                ]
        res = {"current_preds": current_preds, "loss": total_loss}
        return res

    @staticmethod
    def convert_from_torch_state_dict_to_hf(
        state_dict_path, hf_checkpoint_path, config
    ):
        """
        Converts a PyTorch state dict to a Hugging Face pretrained checkpoint.

        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the Hugging Face checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
        """
        # Load the configuration
        if isinstance(config, dict):
            config = RetrieverConfig(**config)
        # Initialize the model
        model = Retriever(config)
        # Load the state dict
        state_dict = torch.load(state_dict_path)
        model.load_state_dict(state_dict)
        # Save as a Hugging Face checkpoint
        model.save_pretrained(hf_checkpoint_path)

    @staticmethod
    def save_encoder_to_hf(state_dict_path, hf_checkpoint_path, config):
        """
        Saves only the encoder part of the model to a specified Hugging Face checkpoint path.

        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the encoder checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
""" # Load the configuration if isinstance(config, dict): config = RetrieverConfig(**config) # Initialize the model model = Retriever(config) # Load the state dict state_dict = torch.load(state_dict_path) model.load_state_dict(state_dict) # Extract the encoder encoder = model.encoder # Save the encoder using Hugging Face's save_pretrained method encoder.save_pretrained(hf_checkpoint_path) model = Retriever.from_pretrained("scholarly-shadows-syndicate/beam_retriever_unofficial")