Viraj45 committed
Commit 0a529d3 · verified · 1 Parent(s): 51274d4
Files changed (6)
  1. app.py +102 -0
  2. item_encoder.pkl +3 -0
  3. model.py +45 -0
  4. requirements.txt +180 -0
  5. user_encoder.pkl +3 -0
  6. user_positive_items.pkl +3 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ # app.py
+
+ import torch
+ import pickle
+ import gradio as gr
+ from model import NCFModel
+ import numpy as np
+
+ # Load encoders and user_positive_items
+ with open('user_encoder.pkl', 'rb') as f:
+     user_encoder = pickle.load(f)
+
+ with open('item_encoder.pkl', 'rb') as f:
+     item_encoder = pickle.load(f)
+
+ with open('user_positive_items.pkl', 'rb') as f:
+     user_positive_items = pickle.load(f)
+
+ # Wrapper that loads the trained model and scores user-item pairs
+ class NCFModelWrapper:
+     def __init__(self, model_path, num_users, num_items, embedding_size=50, device='cpu'):
+         self.device = torch.device(device)
+         self.model = NCFModel(num_users, num_items, embedding_size=embedding_size).to(self.device)
+         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+         self.model.eval()
+
+     def predict(self, user, item):
+         with torch.no_grad():
+             user = torch.tensor([user], dtype=torch.long).to(self.device)
+             item = torch.tensor([item], dtype=torch.long).to(self.device)
+             output = self.model(user, item)
+             score = torch.sigmoid(output).item()
+         return score
+
+ # Determine number of users and items from the fitted encoders
+ num_users = len(user_encoder.classes_)
+ num_items = len(item_encoder.classes_)
+
+ # Initialize the model
+ model = NCFModelWrapper(
+     model_path='best_ncf_model.pth',
+     num_users=num_users,
+     num_items=num_items,
+     embedding_size=50,  # Must match the embedding size used during training
+     device='cpu'        # Change to 'cuda' if a GPU is available and desired
+ )
+
+ def recommend(user_id, num_recommendations=5):
+     """
+     Given a user ID, recommend the top N items the user has not yet interacted with.
+     """
+     try:
+         user = user_encoder.transform([user_id])[0]
+     except ValueError:  # LabelEncoder raises ValueError for unseen IDs
+         return f"User ID '{user_id}' not found."
+
+     # Items the user has already interacted with
+     pos_items = user_positive_items.get(user, set())
+
+     # All possible encoded item indices
+     all_items = set(range(num_items))
+
+     # Candidate items are those not yet interacted with
+     candidate_items = list(all_items - pos_items)
+
+     # Predict a score for each candidate item
+     scores = []
+     for item in candidate_items:
+         score = model.predict(user, item)
+         scores.append((item, score))
+
+     # Sort candidates by score, highest first
+     scores.sort(key=lambda x: x[1], reverse=True)
+
+     # Take the top N; cast to int because Slider values may arrive as floats
+     top_items = scores[:int(num_recommendations)]
+     recommendations = []
+     for item_id, score in top_items:
+         original_item_id = item_encoder.inverse_transform([item_id])[0]
+         recommendations.append(f"Item ID: {original_item_id} (Score: {score:.4f})")
+
+     return "\n".join(recommendations)
+
+ # Define the Gradio interface (gr.inputs.* is deprecated; use the top-level components)
+ iface = gr.Interface(
+     fn=recommend,
+     inputs=[
+         gr.Textbox(lines=1, placeholder="Enter User ID", label="User ID"),
+         gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Recommendations")
+     ],
+     outputs="text",
+     title="Neural Collaborative Filtering Recommendation System",
+     description="Enter a User ID to receive personalized item recommendations.",
+     examples=[
+         ["user_1", 5],
+         ["user_2", 10],
+         ["user_3", 7]
+     ]
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
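
Note: recommend scores candidates one at a time, which means one forward pass per item. For a large catalog, the whole candidate list can be scored in a single batched forward pass. A minimal sketch of such a method on NCFModelWrapper (the predict_batch helper is illustrative, not part of this commit):

    def predict_batch(self, user, items):
        """Score one user against a list of encoded item IDs in a single forward pass."""
        with torch.no_grad():
            users = torch.full((len(items),), int(user), dtype=torch.long, device=self.device)
            items = torch.tensor(items, dtype=torch.long, device=self.device)
            return torch.sigmoid(self.model(users, items)).tolist()

    # Usage inside recommend():
    # scores = list(zip(candidate_items, model.predict_batch(user, candidate_items)))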
item_encoder.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d3794f0651bde4879ca75543c83892f98cfc80e5c59c1a01b866bf15ec1144b
+ size 2058499
model.py ADDED
@@ -0,0 +1,45 @@
+ # src/model.py
+
+ import torch
+ import torch.nn as nn
+
+ class NCFModel(nn.Module):
+     def __init__(self, num_users, num_items, embedding_size=50):
+         """
+         Initialize the NCF model with embedding layers and fully connected layers.
+
+         Args:
+             num_users (int): Total number of unique users.
+             num_items (int): Total number of unique items.
+             embedding_size (int): Size of the embedding vectors.
+         """
+         super(NCFModel, self).__init__()
+         self.user_embedding = nn.Embedding(num_users, embedding_size)
+         self.item_embedding = nn.Embedding(num_items, embedding_size)
+
+         self.fc1 = nn.Linear(embedding_size * 2, 128)
+         self.dropout1 = nn.Dropout(0.5)
+         self.fc2 = nn.Linear(128, 64)
+         self.dropout2 = nn.Dropout(0.5)
+         self.output_layer = nn.Linear(64, 1)
+
+     def forward(self, user, item):
+         """
+         Forward pass through the model.
+
+         Args:
+             user (torch.LongTensor): Tensor of user IDs.
+             item (torch.LongTensor): Tensor of item IDs.
+
+         Returns:
+             torch.Tensor: Output logits indicating interaction likelihood.
+         """
+         user_emb = self.user_embedding(user)
+         item_emb = self.item_embedding(item)
+         x = torch.cat([user_emb, item_emb], dim=1)
+         x = torch.relu(self.fc1(x))
+         x = self.dropout1(x)
+         x = torch.relu(self.fc2(x))
+         x = self.dropout2(x)
+         x = self.output_layer(x)  # No sigmoid here; handled in the loss function
+         return x.squeeze(-1)  # Squeeze only the last dim so a batch of one keeps its batch dimension
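
Note: forward returns raw logits, so training is expected to pair the model with a loss that applies the sigmoid itself. A minimal training-step sketch using nn.BCEWithLogitsLoss (the training script is not part of this commit; the sizes and batch contents below are placeholders):

    import torch
    import torch.nn as nn
    from model import NCFModel

    model = NCFModel(num_users=1000, num_items=5000)  # placeholder sizes
    criterion = nn.BCEWithLogitsLoss()                # applies sigmoid internally
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # users/items: encoded IDs; labels: 1 for observed interactions, 0 for sampled negatives
    users = torch.randint(0, 1000, (32,))
    items = torch.randint(0, 5000, (32,))
    labels = torch.randint(0, 2, (32,)).float()

    logits = model(users, items)      # shape (32,), raw logits
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()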
requirements.txt ADDED
@@ -0,0 +1,180 @@
+ absl-py==2.1.0
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.10.9
+ aiosignal==1.3.1
+ anyio==4.5.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ astunparse==1.6.3
+ async-lru==2.0.4
+ async-timeout==4.0.3
+ attrs==24.2.0
+ babel==2.16.0
+ backcall==0.2.0
+ beautifulsoup4==4.12.3
+ bleach==6.1.0
+ Brotli @ file:///C:/b/abs_3d36mno480/croot/brotli-split_1714483178642/work
+ cachetools==5.5.0
+ certifi @ file:///C:/b/abs_1fw_exq1si/croot/certifi_1725551736618/work/certifi
+ cffi==1.17.1
+ charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
+ colorama==0.4.6
+ comm==0.2.2
+ contourpy==1.1.1
+ cycler==0.12.1
+ datasets==3.0.1
+ debugpy==1.8.6
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ exceptiongroup==1.2.2
+ executing==2.1.0
+ fastjsonschema==2.20.0
+ filelock @ file:///C:/b/abs_f2gie28u58/croot/filelock_1700591233643/work
+ flatbuffers==24.3.25
+ fonttools==4.54.1
+ fqdn==1.5.1
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ gast==0.4.0
+ gmpy2 @ file:///C:/ci/gmpy2_1645456279018/work
+ google-auth==2.35.0
+ google-auth-oauthlib==1.0.0
+ google-pasta==0.2.0
+ grpcio==1.66.2
+ h11==0.14.0
+ h5py==3.11.0
+ httpcore==1.0.6
+ httpx==0.27.2
+ huggingface-hub==0.25.1
+ idna==3.10
+ importlib_metadata==8.5.0
+ importlib_resources==6.4.5
+ ipykernel==6.29.5
+ ipython==8.12.3
+ isoduration==20.11.0
+ jedi==0.19.1
+ Jinja2 @ file:///C:/b/abs_92fccttino/croot/jinja2_1716993447201/work
+ joblib==1.4.2
+ json5==0.9.25
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ jupyter-events==0.10.0
+ jupyter-lsp==2.2.5
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ jupyter_server==2.14.2
+ jupyter_server_terminals==0.5.3
+ jupyterlab==4.2.5
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.27.3
+ kaggle==1.6.17
+ keras==2.13.1
+ kiwisolver==1.4.7
+ libclang==18.1.1
+ Markdown==3.7
+ MarkupSafe @ file:///C:/b/abs_ecfdqh67b_/croot/markupsafe_1704206030535/work
+ matplotlib==3.7.5
+ matplotlib-inline==0.1.7
+ mistune==3.0.2
+ mkl-fft @ file:///C:/b/abs_19i1y8ykas/croot/mkl_fft_1695058226480/work
+ mkl-random @ file:///C:/b/abs_edwkj1_o69/croot/mkl_random_1695059866750/work
+ mkl-service==2.4.0
+ mpmath @ file:///C:/b/abs_7833jrbiox/croot/mpmath_1690848321154/work
+ multidict==6.1.0
+ multiprocess==0.70.16
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx @ file:///C:/b/abs_e6gi1go5op/croot/networkx_1690562046966/work
+ notebook_shim==0.2.4
+ oauthlib==3.2.2
+ opt_einsum==3.4.0
+ overrides==7.7.0
+ packaging==24.1
+ pandas==2.0.3
+ pandocfilters==1.5.1
+ parso==0.8.4
+ pickleshare==0.7.5
+ pillow @ file:///C:/b/abs_32o8er3uqp/croot/pillow_1721059447598/work
+ pkgutil_resolve_name==1.3.10
+ platformdirs==4.3.6
+ prometheus_client==0.21.0
+ prompt_toolkit==3.0.48
+ protobuf==4.25.5
+ psutil==6.0.0
+ pure_eval==0.2.3
+ pyarrow==17.0.0
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.1
+ pycparser==2.22
+ Pygments==2.18.0
+ pyparsing==3.1.4
+ PySocks @ file:///C:/ci/pysocks_1605287845585/work
+ python-dateutil==2.9.0.post0
+ python-json-logger==2.0.7
+ python-slugify==8.0.4
+ pytz==2024.2
+ pywin32==307
+ pywinpty==2.0.13
+ PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
+ pyzmq==26.2.0
+ referencing==0.35.1
+ regex==2024.9.11
+ requests @ file:///C:/b/abs_9frifg92q2/croot/requests_1721410901096/work
+ requests-oauthlib==2.0.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rpds-py==0.20.0
+ rsa==4.9
+ safetensors==0.4.5
+ scikit-learn==1.3.2
+ scipy==1.10.1
+ seaborn==0.13.2
+ Send2Trash==1.8.3
+ six==1.16.0
+ sniffio==1.3.1
+ soupsieve==2.6
+ stack-data==0.6.3
+ sympy @ file:///C:/b/abs_4e4p71hdj_/croot/sympy_1724938208509/work
+ tensorboard==2.13.0
+ tensorboard-data-server==0.7.2
+ tensorflow==2.13.0
+ tensorflow-estimator==2.13.0
+ tensorflow-intel==2.13.0
+ tensorflow-io-gcs-filesystem==0.31.0
+ termcolor==2.4.0
+ terminado==0.18.1
+ text-unidecode==1.3
+ threadpoolctl==3.5.0
+ tinycss2==1.3.0
+ tokenizers==0.20.0
+ tomli==2.0.2
+ torch==2.4.1+cu118
+ torchaudio==2.4.1+cu118
+ torchvision==0.19.1+cu118
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ transformers==4.45.1
+ types-python-dateutil==2.9.0.20241003
+ typing_extensions @ file:///C:/b/abs_0as9mdbkfl/croot/typing_extensions_1715268906610/work
+ tzdata==2024.2
+ uri-template==1.3.0
+ urllib3 @ file:///C:/b/abs_9a_f8h_bn2/croot/urllib3_1727769836930/work
+ wcwidth==0.2.13
+ webcolors==24.8.0
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ Werkzeug==3.0.4
+ win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
+ wrapt==1.16.0
+ xxhash==3.5.0
+ yarl==1.13.1
+ zipp==3.20.2
+ gradio==3.35.0
+ numpy==1.25.2
user_encoder.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e864f77c5736419d426f902fc721d8357e6ea80e3af1f1ad0db8a7d205e43c89
+ size 12938253
user_positive_items.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e7b2dca91fb75790d1660abbda79f1e201aabfb3a90b157417345eb009c42ad
+ size 18454612
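
Note: the three .pkl files are committed as Git LFS pointers. app.py expects user_encoder and item_encoder to behave like fitted sklearn LabelEncoders and user_positive_items to map each encoded user to the set of encoded items they have interacted with. A minimal sketch of how such artifacts could be produced (the interactions DataFrame and its columns are hypothetical; the actual preprocessing script is not part of this commit):

    import pickle
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    # Hypothetical DataFrame with one row per positive user-item interaction
    interactions = pd.DataFrame({
        "user_id": ["user_1", "user_1", "user_2"],
        "item_id": ["item_a", "item_b", "item_a"],
    })

    user_encoder = LabelEncoder().fit(interactions["user_id"])
    item_encoder = LabelEncoder().fit(interactions["item_id"])

    # Map each encoded user to the set of encoded items they interacted with
    users = user_encoder.transform(interactions["user_id"])
    items = item_encoder.transform(interactions["item_id"])
    user_positive_items = {}
    for u, i in zip(users, items):
        user_positive_items.setdefault(u, set()).add(i)

    for name, obj in [("user_encoder.pkl", user_encoder),
                      ("item_encoder.pkl", item_encoder),
                      ("user_positive_items.pkl", user_positive_items)]:
        with open(name, "wb") as f:
            pickle.dump(obj, f)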