Vermeer commited on
Commit
6ebe235
·
verified ·
1 Parent(s): b7a65f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import pickle
7
+
8
+ from torch_geometric.nn import GCNConv, LGConv
9
+ from torch_geometric.utils import degree
10
+ from torch_geometric.nn.conv import MessagePassing
11
+ from torch_geometric.data import HeteroData, Data
12
+ import torch_geometric.transforms as T
13
+ from torch_geometric.nn import LightGCN
14
+ import utils
15
+
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ data = torch.load("processed_MVL_light.pt")
18
+ ch = torch.load('./lightGCNModel_num_layers_MovieLens100K_checkpoint.pt')
19
+ lightGCNModel = LightGCN(
20
+ num_nodes=data.num_nodes,
21
+ embedding_dim=64,
22
+ num_layers=3,
23
+ ).to(device)
24
+ optimizer = torch.optim.Adam(lightGCNModel.parameters(), lr=0.005)
25
+ mask_train = data.edge_index[0] < data.edge_index[1]
26
+ train_edge_label_index = data.edge_index[:, mask_train]
27
+ lightGCNModel.load_state_dict(ch['model_state_dict'])
28
+ optimizer.load_state_dict(ch['optimizer_state_dict'])
29
+ num_items = 1682
30
+ num_users = 943
31
+
32
+ def recommend(user_id):
33
+
34
+ ground_truth_items, recommendations = utils.predict(lightGCNModel, device, data, num_users, num_items, user_id, train_edge_label_index, k=5)
35
+ return ground_truth_items, recommendations
36
+
37
+ iface = gr.Interface(fn=recommend, inputs="number", outputs=["text", "text"])
38
+ iface.launch()