import math
import os
import sys
from typing import Optional, List, Union, Tuple

import torch
from loguru import logger
from torch import nn
from torch.nn import functional as F, CrossEntropyLoss
from torch_geometric.nn import RGCNConv
from transformers import BartPretrainedModel, BartConfig, BartModel
from transformers.modeling_outputs import Seq2SeqLMOutput

sys.path.append("..")

from src.model.utils import SelfAttention, shift_tokens_right


class KBRDforRec(nn.Module):
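    """KBRD-style recommender.

    Entity embeddings are contextualized with an R-GCN over the knowledge
    graph, pooled into a single user embedding by self-attention over the
    entities mentioned in the dialogue, and scored against every entity
    (plus one frozen zero "special token" row) by dot product.
    """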
    def __init__(self, hidden_size, num_relations, num_bases, num_entities):
        super().__init__()

        self.kg_encoder = RGCNConv(
            hidden_size, hidden_size, num_relations=num_relations, num_bases=num_bases
        )
        self.node_embeds = nn.Parameter(torch.empty(num_entities, hidden_size))
        stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
        self.node_embeds.data.uniform_(-stdv, stdv)

        self.special_token_embeddings = nn.Parameter(
            torch.zeros(1, hidden_size), requires_grad=False
        )

        self.attn = SelfAttention(hidden_size)

    def get_node_embeds(self, edge_index, edge_type):
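        """Run the R-GCN over the KG and append the frozen special-token row."""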
        node_embeds = self.kg_encoder(self.node_embeds, edge_index, edge_type)
        node_embeds = torch.cat([node_embeds, self.special_token_embeddings], dim=0)
        return node_embeds

    def forward(
        self,
        entity_embeds=None,
        entity_ids=None,
        edge_index=None,
        edge_type=None,
        node_embeds=None,
        entity_mask=None,
        labels=None,
        reduction="none",
    ):
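        """Score all entities for each user.

        Either precomputed `entity_embeds`, or `entity_ids` together with
        `node_embeds` (or `edge_index`/`edge_type` to compute them), must be
        provided. Returns a dict with the optional cross-entropy loss, the
        entity logits, and the pooled user embeddings.
        """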
        if node_embeds is None:
            node_embeds = self.get_node_embeds(edge_index, edge_type)

        if entity_embeds is None:
            entity_embeds = node_embeds[entity_ids]

        user_embeds = self.attn(entity_embeds, entity_mask)

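        # Dot-product scoring of the user embedding against every node embedding.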
        logits = user_embeds @ node_embeds.T
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels, reduction=reduction)

        return {"loss": loss, "logit": logits, "user_embeds": user_embeds}

    def save(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model.pt")
        torch.save(self.state_dict(), save_path)

    def load(self, load_dir):
        load_path = os.path.join(load_dir, "model.pt")
        missing_keys, unexpected_keys = self.load_state_dict(
            torch.load(load_path, map_location=torch.device("cpu"))
        )


class KBRDforConv(BartPretrainedModel):
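    """BART-based response generator with a per-user bias on the vocabulary logits.

    A user embedding from the recommendation module is projected to vocabulary
    size and added to the LM logits at every decoding step, following KBRD.
    """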
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]

    def __init__(self, config: BartConfig, user_hidden_size):
        super().__init__(config)
        self.model = BartModel(config)
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )
        self.lm_head = nn.Linear(
            config.d_model, self.model.shared.num_embeddings, bias=False
        )

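        # Projects the user embedding to vocabulary size; used as an additive bias on the LM logits.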
        self.rec_proj = nn.Linear(user_hidden_size, self.model.shared.num_embeddings)

        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        decoder_user_embeds=None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
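        """Standard BART conditional-generation forward pass, except that
        `decoder_user_embeds` (one user embedding per example) is projected to
        vocabulary size and added to the LM logits as a bias."""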
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            if use_cache:
                logger.warning(
                    "The `use_cache` argument is changed to `False` since `labels` is provided."
                )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
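        # LM logits plus the fixed final-logits bias and the user-specific vocabulary bias.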
        lm_logits = (
            self.lm_head(outputs[0])
            + self.final_logits_bias
            + self.rec_proj(decoder_user_embeds).unsqueeze(1)
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        decoder_user_embeds=None,
        **kwargs
    ):
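        # With cached past key values, only the last generated token needs to be fed to the decoder.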
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
            "decoder_user_embeds": decoder_user_embeds,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
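            # Only the self-attention key/value states (the first two entries) depend on the
            # beam order; the cached cross-attention states are identical across beams.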
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past[:2]
                )
                + layer_past[2:],
            )
        return reordered_past
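
# Illustrative wiring of the two modules (a sketch only; the variable values and the
# SelfAttention / shift_tokens_right helpers come from the surrounding project):
#
#   rec_model = KBRDforRec(hidden_size, num_relations, num_bases, num_entities)
#   rec_out = rec_model(entity_ids=entity_ids, edge_index=edge_index,
#                       edge_type=edge_type, entity_mask=entity_mask)
#   conv_model = KBRDforConv(bart_config, user_hidden_size=hidden_size)
#   gen_ids = conv_model.generate(input_ids=context_ids,
#                                 decoder_user_embeds=rec_out["user_embeds"])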