|
import json
import os

import numpy as np
import torch
from torch_geometric.data import Data
|
|
|
class GNN(torch.nn.Module):
    """
    Overall graph neural network. Consists of learnable user/item (i.e., playlist/song) embeddings
    and LightGCN layers.

    NOTE(review): `LightGCN` (used in __init__) and `recall_at_k` (used in evaluation) are not
    defined or imported in the visible part of this file; they must be in scope before the
    class is instantiated / evaluated.
    """

    def __init__(self, embedding_dim, num_nodes, num_playlists, num_layers):
        """
        args:
          embedding_dim: size of each node embedding vector
          num_nodes: total number of nodes (rows of the embedding table)
          num_playlists: number of playlist nodes; evaluation() treats rows from this index
            onwards as songs
          num_layers: number of LightGCN propagation layers to stack
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_nodes = num_nodes
        self.num_playlists = num_playlists
        self.num_layers = num_layers

        # One learnable vector per node, initialised from N(0, 0.1^2).
        self.embeddings = torch.nn.Embedding(
            num_embeddings=self.num_nodes, embedding_dim=self.embedding_dim
        )
        torch.nn.init.normal_(self.embeddings.weight, std=0.1)

        # Stack of propagation layers, one per hop.
        self.layers = torch.nn.ModuleList()
        for _ in range(self.num_layers):
            self.layers.append(LightGCN())

        self.sigmoid = torch.sigmoid

    def forward(self):
        """Not supported: use gnn_propagation / predict_scores / calc_loss / evaluation instead."""
        raise NotImplementedError("forward() has not been implemented for the GNN class. Do not use")

    def gnn_propagation(self, edge_index_mp):
        """
        Performs the linear embedding propagation (using the LightGCN layers) and calculates final
        (multi-scale) embeddings for each user/item as the plain average of that node's embeddings
        at every layer (from 0 to self.num_layers).

        args:
          edge_index_mp: a tensor of all (undirected) edges in the graph, which is used for message
            passing/propagation and calculating the multi-scale embeddings. (In contrast to the
            evaluation/supervision edges, which are distinct from the message passing edges and
            will be used for calculating loss/performance metrics).
        returns:
          final multi-scale embeddings for all users/items
        """
        x = self.embeddings.weight

        # Layer 0 is the raw embedding table; each layer then consumes the previous output.
        x_at_each_layer = [x]
        for i in range(self.num_layers):
            x = self.layers[i](x, edge_index_mp)
            x_at_each_layer.append(x)

        # Uniform weighting across layers == mean over the stacked layer axis.
        final_embs = torch.stack(x_at_each_layer, dim=0).mean(dim=0)
        return final_embs

    def predict_scores(self, edge_index, embs):
        """
        Calculates predicted scores for each playlist/song pair in the list of edges, as the
        sigmoid of the dot product of the two endpoint embeddings.

        args:
          edge_index: tensor of edges (between playlists and songs) whose scores we will calculate.
            Row 0 and row 1 hold the two endpoints of each edge.
          embs: node embeddings for calculating predicted scores (typically the multi-scale
            embeddings from gnn_propagation())
        returns:
          predicted scores for each playlist/song pair in edge_index
        """
        # Element-wise product followed by a sum over the feature axis is a row-wise dot product.
        scores = embs[edge_index[0, :], :] * embs[edge_index[1, :], :]
        scores = scores.sum(dim=1)
        scores = self.sigmoid(scores)
        return scores

    def calc_loss(self, data_mp, data_pos, data_neg):
        """
        The main training step. Performs GNN propagation on message passing edges, to get
        multi-scale embeddings. Then predicts scores for each training example, and calculates
        Bayesian Personalized Ranking (BPR) loss.

        args:
          data_mp: object whose `edge_index` holds the edges used for message passing /
            calculating multi-scale embeddings
          data_pos: object whose `edge_index` holds the positive edges used for the loss
          data_neg: object whose `edge_index` holds the negative edges used for the loss
        returns:
          loss calculated on the positive/negative training edges
        """
        final_embs = self.gnn_propagation(data_mp.edge_index)

        pos_scores = self.predict_scores(data_pos.edge_index, final_embs)
        neg_scores = self.predict_scores(data_neg.edge_index, final_embs)

        # predict_scores() already applies a sigmoid, so the difference lies in (-1, 1) and the
        # outer sigmoid/log below stays well away from log(0).
        loss = -torch.log(self.sigmoid(pos_scores - neg_scores)).mean()
        return loss

    def evaluation(self, data_mp, data_pos, k):
        """
        Performs evaluation on validation or test set. Calculates recall@k.

        args:
          data_mp: message passing edges to use for propagation/calculating multi-scale embeddings
          data_pos: positive edges to use for scoring metrics. Should be no overlap between these
            edges and data_mp's edges
          k: value of k to use for recall@k
        returns:
          dictionary mapping playlist ID -> recall@k on that playlist
        """
        final_embs = self.gnn_propagation(data_mp.edge_index)

        # unique_consecutive only collapses adjacent duplicates, so the positive edges are
        # presumably grouped by playlist ID — TODO confirm against the data loader.
        unique_playlists = torch.unique_consecutive(data_pos.edge_index[0, :])
        playlist_emb = final_embs[unique_playlists, :]

        # Every row from index num_playlists onwards is treated as a song embedding.
        song_emb = final_embs[self.num_playlists:, :]

        # Score every selected playlist against every song in one matrix product.
        ratings = self.sigmoid(torch.matmul(playlist_emb, song_emb.t()))

        # The metric is computed from CPU copies of the tensors.
        result = recall_at_k(ratings.cpu(), k, self.num_playlists, data_pos.edge_index.cpu(),
                             unique_playlists.cpu(), data_mp.edge_index.cpu())
        return result
|
|
|
|
|
|
|
# Load the preprocessed graph object and dataset statistics, rebuild the model with the
# hyperparameters below, and restore its trained weights.
# NOTE(review): `base_dir` is not defined anywhere in the visible file; it has to be set
# before this point — TODO confirm where the dataset directory is meant to come from.
# NOTE(review): recent PyTorch releases may default torch.load to weights-only unpickling,
# which can reject a pickled graph object — confirm against the installed version.
data = torch.load(os.path.join(base_dir, "data_object.pt"))
with open(os.path.join(base_dir, "dataset_stats.json"), "r", encoding="utf-8") as f:
    stats = json.load(f)
# `num_nodes` from the stats file is kept as a module-level name, but the model below is
# sized from the loaded graph object itself.
num_playlists, num_nodes = stats["num_playlists"], stats["num_nodes"]
model = GNN(embedding_dim=64, num_nodes=data.num_nodes, num_playlists=num_playlists, num_layers=3)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a CPU-only host;
# load_state_dict then copies the values into the (CPU) model parameters.
model.load_state_dict(torch.load("pesos_modelo.pth", map_location="cpu"))
|
|
|
|
|
def predict(edge_index):
    """
    Runs the module-level `model` in inference mode over the given edges and returns the
    result of its gnn_propagation() step.

    args:
      edge_index: edges to propagate over; may be a torch tensor, a NumPy array or a nested
        list of integer node indices (presumably laid out as 2 x num_edges, as in the
        __main__ example — TODO confirm)
    returns:
      the multi-scale node embeddings produced by model.gnn_propagation()
    """
    # Message passing needs an integer tensor; the caller may hand in a NumPy array or a
    # list, so normalise the input and place it on the same device as the model parameters.
    edge_index = torch.as_tensor(edge_index, dtype=torch.long)
    edge_index = edge_index.to(model.embeddings.weight.device)

    # Local name chosen so the module-level `data` object is not shadowed.
    graph = Data(edge_index=edge_index)

    model.eval()
    with torch.no_grad():
        output = model.gnn_propagation(graph.edge_index)

    return output
|
|
|
|
|
if __name__ == "__main__":
    # Toy edge list: row 0 holds one endpoint of each edge, row 1 the other.
    # Built directly as an integer tensor because propagation operates on torch tensors,
    # not NumPy arrays.
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
    predictions = predict(edge_index)
    print(predictions)
|
|