Spaces:
Sleeping
Sleeping
updated
Browse files- app.py +102 -0
- item_encoder.pkl +3 -0
- model.py +45 -0
- requirements.txt +183 -0
- user_encoder.pkl +3 -0
- user_positive_items.pkl +3 -0
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import pickle
|
5 |
+
import gradio as gr
|
6 |
+
from model import NCFModel
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# Load encoders and user_positive_items
|
10 |
+
def _load_pickle(path):
    """Return the object deserialized from the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Artifacts shipped next to this script. The encoders expose ``classes_``,
# ``transform`` and ``inverse_transform`` (presumably scikit-learn
# LabelEncoders -- confirm against the training code); user_positive_items is
# looked up by encoded user index further down.
user_encoder = _load_pickle('user_encoder.pkl')
item_encoder = _load_pickle('item_encoder.pkl')
user_positive_items = _load_pickle('user_positive_items.pkl')
|
18 |
+
|
19 |
+
# Load the trained model
|
20 |
+
class NCFModelWrapper:
    """Inference-only wrapper around a trained ``NCFModel`` checkpoint.

    Builds the network, loads its weights from ``model_path`` and switches it
    to evaluation mode so dropout is disabled while scoring.
    """

    def __init__(self, model_path, num_users, num_items, embedding_size=50, device='cpu'):
        """Create the network and load the saved state dict.

        Args:
            model_path: Path to the file holding the saved ``state_dict``.
            num_users: Number of rows in the user embedding table.
            num_items: Number of rows in the item embedding table.
            embedding_size: Embedding width; must match the checkpoint.
            device: Torch device string the model is placed on.
        """
        self.device = torch.device(device)
        self.model = NCFModel(num_users, num_items, embedding_size=embedding_size).to(self.device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

    def predict(self, user, item):
        """Return the sigmoid interaction score for one (user, item) pair.

        Args:
            user: Encoded (integer) user index.
            item: Encoded (integer) item index.

        Returns:
            float: Score in the range [0, 1].
        """
        return self.predict_many(user, [item])[0]

    def predict_many(self, user, items):
        """Score several items for a single user in one forward pass.

        Args:
            user: Encoded (integer) user index.
            items: Iterable of encoded (integer) item indices.

        Returns:
            list[float]: One sigmoid score per entry of ``items``, in the same
            order. An empty input yields an empty list.
        """
        item_ids = [int(item) for item in items]
        if not item_ids:
            return []
        with torch.no_grad():
            # Create the tensors directly on the target device.
            item_tensor = torch.tensor(item_ids, dtype=torch.long, device=self.device)
            user_tensor = torch.full_like(item_tensor, int(user))
            logits = self.model(user_tensor, item_tensor)
            # reshape(-1) tolerates a model that squeezes a single-element
            # batch down to a 0-d tensor.
            return torch.sigmoid(logits).reshape(-1).tolist()
|
34 |
+
|
35 |
+
# Determine number of users and items from encoders
|
36 |
+
# The embedding tables are sized from the number of classes each encoder knows.
num_users, num_items = (
    len(encoder.classes_) for encoder in (user_encoder, item_encoder)
)

# Build the inference wrapper around the saved checkpoint.
model = NCFModelWrapper(
    'best_ncf_model.pth',
    num_users,
    num_items,
    embedding_size=50,  # Ensure this matches your trained model
    device='cpu',  # Change to 'cuda' if GPU is available and desired
)
|
47 |
+
|
48 |
+
def recommend(user_id, num_recommendations=5):
    """Recommend the top-N unseen items for a user.

    Args:
        user_id: Raw user identifier as entered in the UI; it is mapped to an
            encoded index through ``user_encoder``.
        num_recommendations: How many items to return. Coerced to ``int``
            because UI sliders may deliver a float.

    Returns:
        str: One ``"Item ID: ... (Score: ...)"`` line per recommendation,
        best first, or a "not found" message when ``user_id`` cannot be
        encoded. Empty string when there is nothing to recommend.
    """
    try:
        user = user_encoder.transform([user_id])[0]
    except (KeyError, TypeError, ValueError):
        # Unknown label or a value of the wrong type for the encoder.
        return f"User ID '{user_id}' not found."

    top_n = max(int(num_recommendations), 0)

    # Items the user already interacted with are excluded from the candidates.
    # Wrapping in set() keeps this working if a list was stored per user.
    seen_items = set(user_positive_items.get(user, ()))
    candidate_items = [item for item in range(num_items) if item not in seen_items]
    if not candidate_items or top_n == 0:
        return ""

    # Score every candidate in a single forward pass instead of one model
    # call per item.
    with torch.no_grad():
        item_tensor = torch.tensor(candidate_items, dtype=torch.long, device=model.device)
        user_tensor = torch.full_like(item_tensor, int(user))
        logits = model.model(user_tensor, item_tensor)
        # reshape(-1) restores the batch axis if the model squeezed a
        # single-candidate batch down to a 0-d tensor.
        scores = torch.sigmoid(logits).reshape(-1)
        # topk returns its values sorted in descending order.
        top_scores, top_positions = torch.topk(scores, min(top_n, len(candidate_items)))
        top_item_ids = item_tensor[top_positions].tolist()
        top_scores = top_scores.tolist()

    # Decode all winners back to their original item IDs in one call.
    original_item_ids = item_encoder.inverse_transform(top_item_ids)
    recommendations = [
        f"Item ID: {original_item_id} (Score: {score:.4f})"
        for original_item_id, score in zip(original_item_ids, top_scores)
    ]
    return "\n".join(recommendations)
|
83 |
+
|
84 |
+
# Define Gradio interface
|
85 |
+
# Gradio UI: a user-ID text box and a result-count slider feeding recommend().
# Components are taken from the top-level gradio namespace; the older
# ``gr.inputs`` namespace and its ``default=`` keyword are deprecated.
iface = gr.Interface(
    fn=recommend,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter User ID", label="User ID"),
        gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Recommendations"),
    ],
    outputs="text",
    title="Neural Collaborative Filtering Recommendation System",
    description="Enter a User ID to receive personalized item recommendations.",
    examples=[
        ["user_1", 5],
        ["user_2", 10],
        ["user_3", 7],
    ],
)

if __name__ == "__main__":
    iface.launch()
|
item_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d3794f0651bde4879ca75543c83892f98cfc80e5c59c1a01b866bf15ec1144b
|
3 |
+
size 2058499
|
model.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/model.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
class NCFModel(nn.Module):
    """Neural collaborative filtering network.

    User and item embeddings are concatenated and passed through two
    ReLU + dropout hidden layers to a single output logit.
    """

    def __init__(self, num_users, num_items, embedding_size=50):
        """Initialize the embedding layers and fully connected layers.

        Args:
            num_users (int): Total number of unique users.
            num_items (int): Total number of unique items.
            embedding_size (int): Size of the embedding vectors.
        """
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_size)
        self.item_embedding = nn.Embedding(num_items, embedding_size)

        self.fc1 = nn.Linear(embedding_size * 2, 128)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.5)
        self.output_layer = nn.Linear(64, 1)

    def forward(self, user, item):
        """Run a forward pass.

        Args:
            user (torch.LongTensor): 1-D tensor of user indices.
            item (torch.LongTensor): 1-D tensor of item indices, same length
                as ``user``.

        Returns:
            torch.Tensor: Logits of shape ``(batch,)``. No sigmoid is applied
            here; callers apply it (or use a logit-based loss).
        """
        user_emb = self.user_embedding(user)
        item_emb = self.item_embedding(item)
        x = torch.cat([user_emb, item_emb], dim=1)
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.output_layer(x)
        # Drop only the trailing feature axis; a bare squeeze() would also
        # remove the batch axis when the batch holds a single sample.
        return x.squeeze(-1)
|
requirements.txt
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
aiohappyeyeballs==2.4.3
|
3 |
+
aiohttp==3.10.9
|
4 |
+
aiosignal==1.3.1
|
5 |
+
anyio==4.5.0
|
6 |
+
argon2-cffi==23.1.0
|
7 |
+
argon2-cffi-bindings==21.2.0
|
8 |
+
arrow==1.3.0
|
9 |
+
asttokens==2.4.1
|
10 |
+
astunparse==1.6.3
|
11 |
+
async-lru==2.0.4
|
12 |
+
async-timeout==4.0.3
|
13 |
+
attrs==24.2.0
|
14 |
+
babel==2.16.0
|
15 |
+
backcall==0.2.0
|
16 |
+
beautifulsoup4==4.12.3
|
17 |
+
bleach==6.1.0
|
18 |
+
Brotli @ file:///C:/b/abs_3d36mno480/croot/brotli-split_1714483178642/work
|
19 |
+
cachetools==5.5.0
|
20 |
+
certifi @ file:///C:/b/abs_1fw_exq1si/croot/certifi_1725551736618/work/certifi
|
21 |
+
cffi==1.17.1
|
22 |
+
charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
|
23 |
+
colorama==0.4.6
|
24 |
+
comm==0.2.2
|
25 |
+
contourpy==1.1.1
|
26 |
+
cycler==0.12.1
|
27 |
+
datasets==3.0.1
|
28 |
+
debugpy==1.8.6
|
29 |
+
decorator==5.1.1
|
30 |
+
defusedxml==0.7.1
|
31 |
+
dill==0.3.8
|
32 |
+
exceptiongroup==1.2.2
|
33 |
+
executing==2.1.0
|
34 |
+
fastjsonschema==2.20.0
|
35 |
+
filelock @ file:///C:/b/abs_f2gie28u58/croot/filelock_1700591233643/work
|
36 |
+
flatbuffers==24.3.25
|
37 |
+
fonttools==4.54.1
|
38 |
+
fqdn==1.5.1
|
39 |
+
frozenlist==1.4.1
|
40 |
+
fsspec==2024.6.1
|
41 |
+
gast==0.4.0
|
42 |
+
gmpy2 @ file:///C:/ci/gmpy2_1645456279018/work
|
43 |
+
google-auth==2.35.0
|
44 |
+
google-auth-oauthlib==1.0.0
|
45 |
+
google-pasta==0.2.0
|
46 |
+
grpcio==1.66.2
|
47 |
+
h11==0.14.0
|
48 |
+
h5py==3.11.0
|
49 |
+
httpcore==1.0.6
|
50 |
+
httpx==0.27.2
|
51 |
+
huggingface-hub==0.25.1
|
52 |
+
idna==3.10
|
53 |
+
importlib_metadata==8.5.0
|
54 |
+
importlib_resources==6.4.5
|
55 |
+
ipykernel==6.29.5
|
56 |
+
ipython==8.12.3
|
57 |
+
isoduration==20.11.0
|
58 |
+
jedi==0.19.1
|
59 |
+
Jinja2 @ file:///C:/b/abs_92fccttino/croot/jinja2_1716993447201/work
|
60 |
+
joblib==1.4.2
|
61 |
+
json5==0.9.25
|
62 |
+
jsonpointer==3.0.0
|
63 |
+
jsonschema==4.23.0
|
64 |
+
jsonschema-specifications==2023.12.1
|
65 |
+
jupyter-events==0.10.0
|
66 |
+
jupyter-lsp==2.2.5
|
67 |
+
jupyter_client==8.6.3
|
68 |
+
jupyter_core==5.7.2
|
69 |
+
jupyter_server==2.14.2
|
70 |
+
jupyter_server_terminals==0.5.3
|
71 |
+
jupyterlab==4.2.5
|
72 |
+
jupyterlab_pygments==0.3.0
|
73 |
+
jupyterlab_server==2.27.3
|
74 |
+
kaggle==1.6.17
|
75 |
+
keras==2.13.1
|
76 |
+
kiwisolver==1.4.7
|
77 |
+
libclang==18.1.1
|
78 |
+
Markdown==3.7
|
79 |
+
MarkupSafe @ file:///C:/b/abs_ecfdqh67b_/croot/markupsafe_1704206030535/work
|
80 |
+
matplotlib==3.7.5
|
81 |
+
matplotlib-inline==0.1.7
|
82 |
+
mistune==3.0.2
|
83 |
+
mkl-fft @ file:///C:/b/abs_19i1y8ykas/croot/mkl_fft_1695058226480/work
|
84 |
+
mkl-random @ file:///C:/b/abs_edwkj1_o69/croot/mkl_random_1695059866750/work
|
85 |
+
mkl-service==2.4.0
|
86 |
+
mpmath @ file:///C:/b/abs_7833jrbiox/croot/mpmath_1690848321154/work
|
87 |
+
multidict==6.1.0
|
88 |
+
multiprocess==0.70.16
|
89 |
+
nbclient==0.10.0
|
90 |
+
nbconvert==7.16.4
|
91 |
+
nbformat==5.10.4
|
92 |
+
nest-asyncio==1.6.0
|
93 |
+
networkx @ file:///C:/b/abs_e6gi1go5op/croot/networkx_1690562046966/work
|
94 |
+
notebook_shim==0.2.4
|
95 |
+
numpy @ file:///C:/Users/dev-admin/mkl/numpy_and_numpy_base_1682982345978/work
|
96 |
+
oauthlib==3.2.2
|
97 |
+
opt_einsum==3.4.0
|
98 |
+
overrides==7.7.0
|
99 |
+
packaging==24.1
|
100 |
+
pandas==2.0.3
|
101 |
+
pandocfilters==1.5.1
|
102 |
+
parso==0.8.4
|
103 |
+
pickleshare==0.7.5
|
104 |
+
pillow @ file:///C:/b/abs_32o8er3uqp/croot/pillow_1721059447598/work
|
105 |
+
pkgutil_resolve_name==1.3.10
|
106 |
+
platformdirs==4.3.6
|
107 |
+
prometheus_client==0.21.0
|
108 |
+
prompt_toolkit==3.0.48
|
109 |
+
protobuf==4.25.5
|
110 |
+
psutil==6.0.0
|
111 |
+
pure_eval==0.2.3
|
112 |
+
pyarrow==17.0.0
|
113 |
+
pyasn1==0.6.1
|
114 |
+
pyasn1_modules==0.4.1
|
115 |
+
pycparser==2.22
|
116 |
+
Pygments==2.18.0
|
117 |
+
pyparsing==3.1.4
|
118 |
+
PySocks @ file:///C:/ci/pysocks_1605287845585/work
|
119 |
+
python-dateutil==2.9.0.post0
|
120 |
+
python-json-logger==2.0.7
|
121 |
+
python-slugify==8.0.4
|
122 |
+
pytz==2024.2
|
123 |
+
pywin32==307
|
124 |
+
pywinpty==2.0.13
|
125 |
+
PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
|
126 |
+
pyzmq==26.2.0
|
127 |
+
referencing==0.35.1
|
128 |
+
regex==2024.9.11
|
129 |
+
requests @ file:///C:/b/abs_9frifg92q2/croot/requests_1721410901096/work
|
130 |
+
requests-oauthlib==2.0.0
|
131 |
+
rfc3339-validator==0.1.4
|
132 |
+
rfc3986-validator==0.1.1
|
133 |
+
rpds-py==0.20.0
|
134 |
+
rsa==4.9
|
135 |
+
safetensors==0.4.5
|
136 |
+
scikit-learn==1.3.2
|
137 |
+
scipy==1.10.1
|
138 |
+
seaborn==0.13.2
|
139 |
+
Send2Trash==1.8.3
|
140 |
+
six==1.16.0
|
141 |
+
sniffio==1.3.1
|
142 |
+
soupsieve==2.6
|
143 |
+
stack-data==0.6.3
|
144 |
+
sympy @ file:///C:/b/abs_4e4p71hdj_/croot/sympy_1724938208509/work
|
145 |
+
tensorboard==2.13.0
|
146 |
+
tensorboard-data-server==0.7.2
|
147 |
+
tensorflow==2.13.0
|
148 |
+
tensorflow-estimator==2.13.0
|
149 |
+
tensorflow-intel==2.13.0
|
150 |
+
tensorflow-io-gcs-filesystem==0.31.0
|
151 |
+
termcolor==2.4.0
|
152 |
+
terminado==0.18.1
|
153 |
+
text-unidecode==1.3
|
154 |
+
threadpoolctl==3.5.0
|
155 |
+
tinycss2==1.3.0
|
156 |
+
tokenizers==0.20.0
|
157 |
+
tomli==2.0.2
|
158 |
+
torch==2.4.1+cu118
|
159 |
+
torchaudio==2.4.1+cu118
|
160 |
+
torchvision==0.19.1+cu118
|
161 |
+
tornado==6.4.1
|
162 |
+
tqdm==4.66.5
|
163 |
+
traitlets==5.14.3
|
164 |
+
transformers==4.45.1
|
165 |
+
types-python-dateutil==2.9.0.20241003
|
166 |
+
typing_extensions @ file:///C:/b/abs_0as9mdbkfl/croot/typing_extensions_1715268906610/work
|
167 |
+
tzdata==2024.2
|
168 |
+
uri-template==1.3.0
|
169 |
+
urllib3 @ file:///C:/b/abs_9a_f8h_bn2/croot/urllib3_1727769836930/work
|
170 |
+
wcwidth==0.2.13
|
171 |
+
webcolors==24.8.0
|
172 |
+
webencodings==0.5.1
|
173 |
+
websocket-client==1.8.0
|
174 |
+
Werkzeug==3.0.4
|
175 |
+
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
|
176 |
+
wrapt==1.16.0
|
177 |
+
xxhash==3.5.0
|
178 |
+
yarl==1.13.1
|
179 |
+
zipp==3.20.2
|
180 |
+
torch==2.0.1
|
181 |
+
gradio==3.35.0
|
182 |
+
scikit-learn==1.2.2
|
183 |
+
numpy==1.25.2
|
user_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e864f77c5736419d426f902fc721d8357e6ea80e3af1f1ad0db8a7d205e43c89
|
3 |
+
size 12938253
|
user_positive_items.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e7b2dca91fb75790d1660abbda79f1e201aabfb3a90b157417345eb009c42ad
|
3 |
+
size 18454612
|