Spaces:
Sleeping
Sleeping
File size: 1,415 Bytes
6ebe235 69d6710 5b96629 6ebe235 69d6710 6ebe235 69d6710 6ebe235 dfe2f0b 7973b9f 6ebe235 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch_geometric.nn import GCNConv, LGConv
from torch_geometric.utils import degree
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import HeteroData, Data
import torch_geometric.transforms as T
from torch_geometric.nn import LightGCN
import utils
# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Preprocessed interaction graph (presumably MovieLens-100K, per the checkpoint
# filename). map_location makes tensors saved on any device load onto `device`.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# refuse to unpickle a PyG Data object — confirm against the pinned torch version.
data = torch.load("processed_MVL_light.pt", map_location=device)

# Training checkpoint; only its 'model_state_dict' entry is needed for inference.
ch = torch.load('./lightGCNModel_num_layers_MovieLens100K_checkpoint.pt', map_location=device)

# Hyperparameters must match the ones used at training time: load_state_dict
# (strict by default) raises on any parameter/buffer size mismatch.
lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)
lightGCNModel.load_state_dict(ch['model_state_dict'])
# The app never trains, so put the model in inference mode explicitly.
lightGCNModel.eval()
# load_state_dict copied the weights into the model; drop the checkpoint dict
# (which may also hold optimizer state) so it is not kept alive for the app's lifetime.
del ch

# Keep one direction of each edge: those whose source index is smaller than its
# target index. NOTE(review): presumably users are numbered before items, so this
# selects user->item interaction edges for utils.predict — confirm.
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

# Node counts matching the MovieLens-100K dataset (943 users, 1,682 movies).
# NOTE(review): presumably num_users + num_items == data.num_nodes — confirm.
num_items = 1682
num_users = 943
def recommend(user_id):
    """Return a user's ground-truth items and top-5 recommendations as display strings.

    Args:
        user_id: Value from the Gradio number input. Gradio may deliver it as a
            float (e.g. ``12.0``) or as ``None`` when the field is empty, so it is
            validated and converted to ``int`` before being passed on.

    Returns:
        A 2-tuple of strings for the interface's two text outputs: the
        space-joined ``ground_truth_items['title']`` values and the space-joined
        ``recommendations``, both as returned by ``utils.predict``.

    Raises:
        gr.Error: If ``user_id`` is missing or not a whole finite number; Gradio
            shows the message in the UI instead of a stack trace.
    """
    if user_id is None:
        raise gr.Error("Please enter a user ID.")
    # float(...).is_integer() is False for fractional values, NaN and infinity,
    # all of which would otherwise be truncated or crash inside int().
    if not float(user_id).is_integer():
        raise gr.Error(f"User ID must be a whole number, got {user_id}.")
    user_id = int(user_id)

    # Inference only: avoid building an autograd graph through the embeddings.
    with torch.no_grad():
        ground_truth_items, recommendations = utils.predict(
            lightGCNModel, device, data, num_users, num_items,
            user_id, train_edge_label_index, k=5,
        )
    return ' '.join(ground_truth_items['title'].tolist()), ' '.join(recommendations)
# Web UI: a single integer input (the user ID) feeding recommend(), whose two
# returned strings fill the two text boxes in order.
iface = gr.Interface(
    fn=recommend,
    # precision=0 makes Gradio hand recommend() an int rather than a float.
    inputs=gr.Number(label="User ID", precision=0),
    outputs=[
        gr.Textbox(label="Ground-truth items"),
        gr.Textbox(label="Recommendations"),
    ],
)

# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    iface.launch()