# src/model/kbrd/kbrd_model.py
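"""KBRD model components.

`KBRDforRec` scores knowledge-graph entities for a user by encoding the graph
with an R-GCN and pooling the entities mentioned in the dialogue with
self-attention; `KBRDforConv` is a BART-based generator whose vocabulary
logits are biased by that user embedding.
"""
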
import math
import os
import sys
from typing import List, Optional, Tuple, Union

import torch
from loguru import logger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch_geometric.nn import RGCNConv
from transformers import BartConfig, BartModel, BartPretrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

sys.path.append("..")
from src.model.utils import SelfAttention, shift_tokens_right


class KBRDforRec(nn.Module):
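    """Recommendation module of KBRD.

    Entities from the dialogue history are looked up in an R-GCN-encoded
    knowledge graph, pooled into a single user embedding with self-attention,
    and scored against every entity embedding by dot product.
    """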
def __init__(self, hidden_size, num_relations, num_bases, num_entities):
        super().__init__()
# kg encoder
self.kg_encoder = RGCNConv(
hidden_size, hidden_size, num_relations=num_relations, num_bases=num_bases
)
self.node_embeds = nn.Parameter(torch.empty(num_entities, hidden_size))
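        # Glorot/Xavier-uniform initialization: bound = sqrt(6 / (fan_in + fan_out)).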
stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
self.node_embeds.data.uniform_(-stdv, stdv)
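        # A frozen all-zero row appended after the KG entities in get_node_embeds,
        # presumably serving as a padding / "no entity" slot for entity_ids.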
self.special_token_embeddings = nn.Parameter(
torch.zeros(1, hidden_size), requires_grad=False
)
self.attn = SelfAttention(hidden_size)
def get_node_embeds(self, edge_index, edge_type):
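        """Encode all entities with the R-GCN and append the frozen zero row."""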
node_embeds = self.kg_encoder(self.node_embeds, edge_index, edge_type)
node_embeds = torch.cat([node_embeds, self.special_token_embeddings], dim=0)
return node_embeds
def forward(
self,
entity_embeds=None,
entity_ids=None,
edge_index=None,
edge_type=None,
node_embeds=None,
entity_mask=None,
labels=None,
reduction="none",
):
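        """Score all KG entities for each user.

        Either precomputed `entity_embeds` or `entity_ids` (indices into the
        node embeddings) must be given; `node_embeds` is recomputed from the
        graph when not supplied.
        """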
if node_embeds is None:
node_embeds = self.get_node_embeds(edge_index, edge_type)
if entity_embeds is None:
entity_embeds = node_embeds[entity_ids] # (bs, seq_len, hs)
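        # Pool the mentioned-entity embeddings into one user representation,
        # then score every node (entities + the appended zero row) by dot product.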
user_embeds = self.attn(entity_embeds, entity_mask)
logits = user_embeds @ node_embeds.T # (bs, n_node)
loss = None
if labels is not None:
loss = F.cross_entropy(logits, labels, reduction=reduction)
return {"loss": loss, "logit": logits, "user_embeds": user_embeds}
def save(self, save_dir):
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "model.pt")
torch.save(self.state_dict(), save_path)
    def load(self, load_dir):
        load_path = os.path.join(load_dir, "model.pt")
        self.load_state_dict(
            torch.load(load_path, map_location=torch.device("cpu"))
        )


class KBRDforConv(BartPretrainedModel):
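    """Conversation module of KBRD: a BART encoder-decoder with an LM head
    whose output logits are additionally biased by a linear projection of the
    user embedding produced by the recommendation module."""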
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
def __init__(self, config: BartConfig, user_hidden_size):
super().__init__(config)
self.model = BartModel(config)
self.register_buffer(
"final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
)
self.lm_head = nn.Linear(
config.d_model, self.model.shared.num_embeddings, bias=False
)
self.rec_proj = nn.Linear(user_hidden_size, self.model.shared.num_embeddings)
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros(
(1, new_num_tokens - old_num_tokens),
device=self.final_logits_bias.device,
)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
decoder_user_embeds=None,
) -> Union[Tuple, Seq2SeqLMOutput]:
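        """Standard BART seq2seq LM forward pass, extended with
        `decoder_user_embeds`: a per-example embedding (e.g., the user embedding
        from KBRDforRec) that is projected into vocabulary space and added to
        the logits at every decoder position.
        """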
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if labels is not None:
if use_cache:
logger.warning(
"The `use_cache` argument is changed to `False` since `labels` is provided."
)
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
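        # Bias the vocabulary distribution toward the user's interests: project the
        # user embedding into vocabulary space and add it to the logits at every
        # decoder position (when provided).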
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        if decoder_user_embeds is not None:
            lm_logits = lm_logits + self.rec_proj(decoder_user_embeds).unsqueeze(1)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(
lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return (
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
)
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
decoder_user_embeds=None,
**kwargs
):
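        # Generation hook (older `past`-style transformers interface): threads the
        # cached states and `decoder_user_embeds` through to forward() at each step.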
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"decoder_user_embeds": decoder_user_embeds,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
@staticmethod
def _reorder_cache(past, beam_idx):
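        # Reorder the cached decoder states so they follow the beams selected
        # at each beam-search step.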
reordered_past = ()
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past