import math
import os

import torch
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import RGCNConv


class KGPrompt(nn.Module):
    """Encodes a knowledge graph with an RGCN and fuses entity and token
    embeddings into prompt embeddings shaped for prefix-style prompting of a
    transformer (one slice per layer and block)."""

    def __init__(
        self,
        hidden_size,
        token_hidden_size,
        n_head,
        n_layer,
        n_block,
        n_entity,
        num_relations,
        num_bases,
        edge_index,
        edge_type,
        n_prefix_rec=None,
        n_prefix_conv=None,
    ):
        super(KGPrompt, self).__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.head_dim = hidden_size // n_head
        self.n_layer = n_layer
        self.n_block = n_block
        self.n_prefix_rec = n_prefix_rec
        self.n_prefix_conv = n_prefix_conv

        entity_hidden_size = hidden_size // 2
        self.kg_encoder = RGCNConv(
            entity_hidden_size,
            entity_hidden_size,
            num_relations=num_relations,
            num_bases=num_bases,
        )
        self.node_embeds = nn.Parameter(torch.empty(n_entity, entity_hidden_size))
        stdv = math.sqrt(6.0 / (self.node_embeds.size(-2) + self.node_embeds.size(-1)))
        self.node_embeds.data.uniform_(-stdv, stdv)
        self.edge_index = nn.Parameter(edge_index, requires_grad=False)
        self.edge_type = nn.Parameter(edge_type, requires_grad=False)
        self.entity_proj1 = nn.Sequential(
            nn.Linear(entity_hidden_size, entity_hidden_size // 2),
            nn.ReLU(),
            nn.Linear(entity_hidden_size // 2, entity_hidden_size),
        )
        self.entity_proj2 = nn.Linear(entity_hidden_size, hidden_size)

        self.token_proj1 = nn.Sequential(
            nn.Linear(token_hidden_size, token_hidden_size // 2),
            nn.ReLU(),
            nn.Linear(token_hidden_size // 2, token_hidden_size),
        )
        self.token_proj2 = nn.Linear(token_hidden_size, hidden_size)

        self.cross_attn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.prompt_proj1 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size),
        )
        self.prompt_proj2 = nn.Linear(hidden_size, n_layer * n_block * hidden_size)

        if self.n_prefix_rec is not None:
            self.rec_prefix_embeds = nn.Parameter(
                torch.empty(n_prefix_rec, hidden_size)
            )
            nn.init.normal_(self.rec_prefix_embeds)
            self.rec_prefix_proj = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.Linear(hidden_size // 2, hidden_size),
            )
        if self.n_prefix_conv is not None:
            self.conv_prefix_embeds = nn.Parameter(
                torch.empty(n_prefix_conv, hidden_size)
            )
            nn.init.normal_(self.conv_prefix_embeds)
            self.conv_prefix_proj = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.Linear(hidden_size // 2, hidden_size),
            )

    def set_and_fix_node_embed(self, node_embeds: torch.Tensor):
        """Overwrite the node embeddings and freeze them."""
        self.node_embeds.data = node_embeds
        self.node_embeds.requires_grad_(False)

    def get_entity_embeds(self):
        """Run the RGCN over the knowledge graph and project the node
        embeddings to the model hidden size, with residual connections."""
        node_embeds = self.node_embeds
        entity_embeds = (
            self.kg_encoder(node_embeds, self.edge_index, self.edge_type) + node_embeds
        )
        entity_embeds = self.entity_proj1(entity_embeds) + entity_embeds
        entity_embeds = self.entity_proj2(entity_embeds)
        return entity_embeds

    def forward(
        self,
        entity_ids=None,
        token_embeds=None,
        output_entity=False,
        use_rec_prefix=False,
        use_conv_prefix=False,
        entity_embeds=None,
    ):
        """Fuse entity and token embeddings (via cross-attention when both are
        given), optionally prepend task prefix embeddings, and return prompt
        embeddings of shape
        (n_layer, n_block, batch_size, n_head, prompt_len, head_dim)."""
        batch_size, entity_len, token_len = None, None, None
        if entity_embeds is not None:
            batch_size, entity_len = entity_embeds.shape[:2]
        elif entity_ids is not None:
            batch_size, entity_len = entity_ids.shape[:2]
            entity_embeds = self.get_entity_embeds()
            entity_embeds = entity_embeds[
                entity_ids
            ]  # (batch_size, entity_len, hidden_size)
        if token_embeds is not None:
            batch_size, token_len = token_embeds.shape[:2]
            token_embeds = self.token_proj1(token_embeds) + token_embeds
            token_embeds = self.token_proj2(
                token_embeds
            )  # (batch_size, token_len, hidden_size)

        if entity_embeds is not None and token_embeds is not None:
            attn_weights = self.cross_attn(token_embeds) @ entity_embeds.permute(
                0, 2, 1
            )  # (batch_size, token_len, entity_len)
            attn_weights /= self.hidden_size
            if output_entity:
                token_weights = F.softmax(attn_weights, dim=1).permute(0, 2, 1)
                prompt_embeds = token_weights @ token_embeds + entity_embeds
                prompt_len = entity_len
            else:
                entity_weights = F.softmax(attn_weights, dim=2)
                prompt_embeds = entity_weights @ entity_embeds + token_embeds
                prompt_len = token_len
        elif entity_embeds is not None:
            prompt_embeds = entity_embeds
            prompt_len = entity_len
        else:
            prompt_embeds = token_embeds
            prompt_len = token_len

        if self.n_prefix_rec is not None and use_rec_prefix:
            prefix_embeds = (
                self.rec_prefix_proj(self.rec_prefix_embeds) + self.rec_prefix_embeds
            )
            prefix_embeds = prefix_embeds.expand(prompt_embeds.shape[0], -1, -1)
            prompt_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
            prompt_len += self.n_prefix_rec
        if self.n_prefix_conv is not None and use_conv_prefix:
            prefix_embeds = (
                self.conv_prefix_proj(self.conv_prefix_embeds) + self.conv_prefix_embeds
            )
            prefix_embeds = prefix_embeds.expand(prompt_embeds.shape[0], -1, -1)
            prompt_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
            prompt_len += self.n_prefix_conv

        prompt_embeds = self.prompt_proj1(prompt_embeds) + prompt_embeds
        prompt_embeds = self.prompt_proj2(prompt_embeds)
        prompt_embeds = prompt_embeds.reshape(
            batch_size,
            prompt_len,
            self.n_layer,
            self.n_block,
            self.n_head,
            self.head_dim,
        ).permute(
            2, 3, 0, 4, 1, 5
        )  # (n_layer, n_block, batch_size, n_head, prompt_len, head_dim)
        return prompt_embeds

    def save(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        # Skip the edge_index / edge_type graph tensors when saving.
        state_dict = {k: v for k, v in self.state_dict().items() if "edge" not in k}
        save_path = os.path.join(save_dir, "model.pt")
        torch.save(state_dict, save_path)

    def load(self, load_dir):
        load_path = os.path.join(load_dir, "model.pt")
        missing_keys, unexpected_keys = self.load_state_dict(
            torch.load(load_path, map_location=torch.device("cpu")), strict=False
        )
        logger.info(f"missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}")
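

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the module's public API). All sizes
# below are illustrative assumptions chosen only to satisfy the constructor's
# constraints (hidden_size divisible by n_head and by 2); they do not come
# from any real training configuration or dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_entity, num_relations, num_edges = 10, 3, 20
    # Random toy graph: edge_index is (2, num_edges), edge_type is (num_edges,).
    edge_index = torch.randint(0, n_entity, (2, num_edges))
    edge_type = torch.randint(0, num_relations, (num_edges,))

    model = KGPrompt(
        hidden_size=64,
        token_hidden_size=48,
        n_head=4,
        n_layer=2,
        n_block=2,
        n_entity=n_entity,
        num_relations=num_relations,
        num_bases=2,
        edge_index=edge_index,
        edge_type=edge_type,
        n_prefix_conv=4,
    )

    batch_size, entity_len, token_len = 3, 5, 7
    entity_ids = torch.randint(0, n_entity, (batch_size, entity_len))
    token_embeds = torch.randn(batch_size, token_len, 48)

    prompt = model(
        entity_ids=entity_ids,
        token_embeds=token_embeds,
        use_conv_prefix=True,
    )
    # Expected: (n_layer, n_block, batch, n_head, token_len + n_prefix_conv, head_dim)
    print(prompt.shape)  # torch.Size([2, 2, 3, 4, 11, 16])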