Joash2024 committed on
Commit
3714bf4
·
1 Parent(s): b7e4e4f

training code

Browse files
Files changed (2) hide show
  1. src/app.py → app.py +0 -0
  2. src/model_training_v2.ipynb +1226 -0
src/app.py → app.py RENAMED
File without changes
src/model_training_v2.ipynb ADDED
@@ -0,0 +1,1226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "# **Music recommender**"
21
+ ],
22
+ "metadata": {
23
+ "id": "DDADPl-phDUC"
24
+ }
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "source": [
29
+ "# **Load Data**"
30
+ ],
31
+ "metadata": {
32
+ "id": "E7Cu5Fmqct7J"
33
+ }
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {
39
+ "colab": {
40
+ "base_uri": "https://localhost:8080/",
41
+ "height": 540
42
+ },
43
+ "id": "bI8bNavbajsv",
44
+ "outputId": "7cba8b5d-4a63-433f-be3c-87ce794833ba"
45
+ },
46
+ "outputs": [
47
+ {
48
+ "output_type": "display_data",
49
+ "data": {
50
+ "text/plain": [
51
+ "<IPython.core.display.HTML object>"
52
+ ],
53
+ "text/html": [
54
+ "\n",
55
+ " <input type=\"file\" id=\"files-793c32c8-99a6-4873-9585-738e1d4b2ab1\" name=\"files[]\" multiple disabled\n",
56
+ " style=\"border:none\" />\n",
57
+ " <output id=\"result-793c32c8-99a6-4873-9585-738e1d4b2ab1\">\n",
58
+ " Upload widget is only available when the cell has been executed in the\n",
59
+ " current browser session. Please rerun this cell to enable.\n",
60
+ " </output>\n",
61
+ " <script>// Copyright 2017 Google LLC\n",
62
+ "//\n",
63
+ "// Licensed under the Apache License, Version 2.0 (the \"License\");\n",
64
+ "// you may not use this file except in compliance with the License.\n",
65
+ "// You may obtain a copy of the License at\n",
66
+ "//\n",
67
+ "// http://www.apache.org/licenses/LICENSE-2.0\n",
68
+ "//\n",
69
+ "// Unless required by applicable law or agreed to in writing, software\n",
70
+ "// distributed under the License is distributed on an \"AS IS\" BASIS,\n",
71
+ "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
72
+ "// See the License for the specific language governing permissions and\n",
73
+ "// limitations under the License.\n",
74
+ "\n",
75
+ "/**\n",
76
+ " * @fileoverview Helpers for google.colab Python module.\n",
77
+ " */\n",
78
+ "(function(scope) {\n",
79
+ "function span(text, styleAttributes = {}) {\n",
80
+ " const element = document.createElement('span');\n",
81
+ " element.textContent = text;\n",
82
+ " for (const key of Object.keys(styleAttributes)) {\n",
83
+ " element.style[key] = styleAttributes[key];\n",
84
+ " }\n",
85
+ " return element;\n",
86
+ "}\n",
87
+ "\n",
88
+ "// Max number of bytes which will be uploaded at a time.\n",
89
+ "const MAX_PAYLOAD_SIZE = 100 * 1024;\n",
90
+ "\n",
91
+ "function _uploadFiles(inputId, outputId) {\n",
92
+ " const steps = uploadFilesStep(inputId, outputId);\n",
93
+ " const outputElement = document.getElementById(outputId);\n",
94
+ " // Cache steps on the outputElement to make it available for the next call\n",
95
+ " // to uploadFilesContinue from Python.\n",
96
+ " outputElement.steps = steps;\n",
97
+ "\n",
98
+ " return _uploadFilesContinue(outputId);\n",
99
+ "}\n",
100
+ "\n",
101
+ "// This is roughly an async generator (not supported in the browser yet),\n",
102
+ "// where there are multiple asynchronous steps and the Python side is going\n",
103
+ "// to poll for completion of each step.\n",
104
+ "// This uses a Promise to block the python side on completion of each step,\n",
105
+ "// then passes the result of the previous step as the input to the next step.\n",
106
+ "function _uploadFilesContinue(outputId) {\n",
107
+ " const outputElement = document.getElementById(outputId);\n",
108
+ " const steps = outputElement.steps;\n",
109
+ "\n",
110
+ " const next = steps.next(outputElement.lastPromiseValue);\n",
111
+ " return Promise.resolve(next.value.promise).then((value) => {\n",
112
+ " // Cache the last promise value to make it available to the next\n",
113
+ " // step of the generator.\n",
114
+ " outputElement.lastPromiseValue = value;\n",
115
+ " return next.value.response;\n",
116
+ " });\n",
117
+ "}\n",
118
+ "\n",
119
+ "/**\n",
120
+ " * Generator function which is called between each async step of the upload\n",
121
+ " * process.\n",
122
+ " * @param {string} inputId Element ID of the input file picker element.\n",
123
+ " * @param {string} outputId Element ID of the output display.\n",
124
+ " * @return {!Iterable<!Object>} Iterable of next steps.\n",
125
+ " */\n",
126
+ "function* uploadFilesStep(inputId, outputId) {\n",
127
+ " const inputElement = document.getElementById(inputId);\n",
128
+ " inputElement.disabled = false;\n",
129
+ "\n",
130
+ " const outputElement = document.getElementById(outputId);\n",
131
+ " outputElement.innerHTML = '';\n",
132
+ "\n",
133
+ " const pickedPromise = new Promise((resolve) => {\n",
134
+ " inputElement.addEventListener('change', (e) => {\n",
135
+ " resolve(e.target.files);\n",
136
+ " });\n",
137
+ " });\n",
138
+ "\n",
139
+ " const cancel = document.createElement('button');\n",
140
+ " inputElement.parentElement.appendChild(cancel);\n",
141
+ " cancel.textContent = 'Cancel upload';\n",
142
+ " const cancelPromise = new Promise((resolve) => {\n",
143
+ " cancel.onclick = () => {\n",
144
+ " resolve(null);\n",
145
+ " };\n",
146
+ " });\n",
147
+ "\n",
148
+ " // Wait for the user to pick the files.\n",
149
+ " const files = yield {\n",
150
+ " promise: Promise.race([pickedPromise, cancelPromise]),\n",
151
+ " response: {\n",
152
+ " action: 'starting',\n",
153
+ " }\n",
154
+ " };\n",
155
+ "\n",
156
+ " cancel.remove();\n",
157
+ "\n",
158
+ " // Disable the input element since further picks are not allowed.\n",
159
+ " inputElement.disabled = true;\n",
160
+ "\n",
161
+ " if (!files) {\n",
162
+ " return {\n",
163
+ " response: {\n",
164
+ " action: 'complete',\n",
165
+ " }\n",
166
+ " };\n",
167
+ " }\n",
168
+ "\n",
169
+ " for (const file of files) {\n",
170
+ " const li = document.createElement('li');\n",
171
+ " li.append(span(file.name, {fontWeight: 'bold'}));\n",
172
+ " li.append(span(\n",
173
+ " `(${file.type || 'n/a'}) - ${file.size} bytes, ` +\n",
174
+ " `last modified: ${\n",
175
+ " file.lastModifiedDate ? file.lastModifiedDate.toLocaleDateString() :\n",
176
+ " 'n/a'} - `));\n",
177
+ " const percent = span('0% done');\n",
178
+ " li.appendChild(percent);\n",
179
+ "\n",
180
+ " outputElement.appendChild(li);\n",
181
+ "\n",
182
+ " const fileDataPromise = new Promise((resolve) => {\n",
183
+ " const reader = new FileReader();\n",
184
+ " reader.onload = (e) => {\n",
185
+ " resolve(e.target.result);\n",
186
+ " };\n",
187
+ " reader.readAsArrayBuffer(file);\n",
188
+ " });\n",
189
+ " // Wait for the data to be ready.\n",
190
+ " let fileData = yield {\n",
191
+ " promise: fileDataPromise,\n",
192
+ " response: {\n",
193
+ " action: 'continue',\n",
194
+ " }\n",
195
+ " };\n",
196
+ "\n",
197
+ " // Use a chunked sending to avoid message size limits. See b/62115660.\n",
198
+ " let position = 0;\n",
199
+ " do {\n",
200
+ " const length = Math.min(fileData.byteLength - position, MAX_PAYLOAD_SIZE);\n",
201
+ " const chunk = new Uint8Array(fileData, position, length);\n",
202
+ " position += length;\n",
203
+ "\n",
204
+ " const base64 = btoa(String.fromCharCode.apply(null, chunk));\n",
205
+ " yield {\n",
206
+ " response: {\n",
207
+ " action: 'append',\n",
208
+ " file: file.name,\n",
209
+ " data: base64,\n",
210
+ " },\n",
211
+ " };\n",
212
+ "\n",
213
+ " let percentDone = fileData.byteLength === 0 ?\n",
214
+ " 100 :\n",
215
+ " Math.round((position / fileData.byteLength) * 100);\n",
216
+ " percent.textContent = `${percentDone}% done`;\n",
217
+ "\n",
218
+ " } while (position < fileData.byteLength);\n",
219
+ " }\n",
220
+ "\n",
221
+ " // All done.\n",
222
+ " yield {\n",
223
+ " response: {\n",
224
+ " action: 'complete',\n",
225
+ " }\n",
226
+ " };\n",
227
+ "}\n",
228
+ "\n",
229
+ "scope.google = scope.google || {};\n",
230
+ "scope.google.colab = scope.google.colab || {};\n",
231
+ "scope.google.colab._files = {\n",
232
+ " _uploadFiles,\n",
233
+ " _uploadFilesContinue,\n",
234
+ "};\n",
235
+ "})(self);\n",
236
+ "</script> "
237
+ ]
238
+ },
239
+ "metadata": {}
240
+ },
241
+ {
242
+ "output_type": "stream",
243
+ "name": "stdout",
244
+ "text": [
245
+ "Saving music_data.csv to music_data.csv\n",
246
+ " title \\\n",
247
+ "0 100 Club 1996 ''We Love You Beatles'' - Live \n",
248
+ "1 Yo Quiero Contigo \n",
249
+ "4 Emerald \n",
250
+ "6 Karma \n",
251
+ "7 Money Blues \n",
252
+ "\n",
253
+ " release artist_name duration \\\n",
254
+ "0 Sex Pistols - The Interviews Sex Pistols 88.73751 \n",
255
+ "1 Sentenciados - Platinum Edition Baby Rasta & Gringo 167.36608 \n",
256
+ "4 Emerald Bedrock 501.86404 \n",
257
+ "6 The Diary Of Alicia Keys Alicia Keys 255.99955 \n",
258
+ "7 Slidetime Joanna Connor 243.66975 \n",
259
+ "\n",
260
+ " artist_familiarity artist_hotttnesss year listeners playcount \\\n",
261
+ "0 0.731184 0.549204 0 172 210 \n",
262
+ "1 0.610186 0.355320 0 9753 16911 \n",
263
+ "4 0.654039 0.390625 2004 973 2247 \n",
264
+ "6 0.933916 0.778674 2003 250304 1028356 \n",
265
+ "7 0.479218 0.332857 0 429 1008 \n",
266
+ "\n",
267
+ " tags \n",
268
+ "0 The Beatles, title is a full sentence \n",
269
+ "1 Reggaeton, alexis y fido, Eliana, mis videos, ... \n",
270
+ "4 dance \n",
271
+ "6 rnb, soul, Alicia Keys, female vocalists, Karma \n",
272
+ "7 guitar girl, blues \n"
273
+ ]
274
+ }
275
+ ],
276
+ "source": [
277
+ "import pandas as pd\n",
278
+ "from google.colab import files\n",
279
+ "\n",
280
+ "# Upload the file\n",
281
+ "uploaded = files.upload()\n",
282
+ "\n",
283
+ "# Assuming the file is named \"music_data.csv\"\n",
284
+ "data_path = \"music_data.csv\"\n",
285
+ "\n",
286
+ "# Load the data\n",
287
+ "df = pd.read_csv(data_path)\n",
288
+ "df.dropna(inplace=True)\n",
289
+ "\n",
290
+ "# Display the first few rows of the dataset\n",
291
+ "print(df.head())\n"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "source": [
297
+ "df.head()"
298
+ ],
299
+ "metadata": {
300
+ "colab": {
301
+ "base_uri": "https://localhost:8080/",
302
+ "height": 206
303
+ },
304
+ "id": "9E3in0U3dK5I",
305
+ "outputId": "c1d5362a-6a33-4543-ff4d-4e11cf8220ec"
306
+ },
307
+ "execution_count": null,
308
+ "outputs": [
309
+ {
310
+ "output_type": "execute_result",
311
+ "data": {
312
+ "text/plain": [
313
+ " title \\\n",
314
+ "0 100 Club 1996 ''We Love You Beatles'' - Live \n",
315
+ "1 Yo Quiero Contigo \n",
316
+ "4 Emerald \n",
317
+ "6 Karma \n",
318
+ "7 Money Blues \n",
319
+ "\n",
320
+ " release artist_name duration \\\n",
321
+ "0 Sex Pistols - The Interviews Sex Pistols 88.73751 \n",
322
+ "1 Sentenciados - Platinum Edition Baby Rasta & Gringo 167.36608 \n",
323
+ "4 Emerald Bedrock 501.86404 \n",
324
+ "6 The Diary Of Alicia Keys Alicia Keys 255.99955 \n",
325
+ "7 Slidetime Joanna Connor 243.66975 \n",
326
+ "\n",
327
+ " artist_familiarity artist_hotttnesss year listeners playcount \\\n",
328
+ "0 0.731184 0.549204 0 172 210 \n",
329
+ "1 0.610186 0.355320 0 9753 16911 \n",
330
+ "4 0.654039 0.390625 2004 973 2247 \n",
331
+ "6 0.933916 0.778674 2003 250304 1028356 \n",
332
+ "7 0.479218 0.332857 0 429 1008 \n",
333
+ "\n",
334
+ " tags \n",
335
+ "0 The Beatles, title is a full sentence \n",
336
+ "1 Reggaeton, alexis y fido, Eliana, mis videos, ... \n",
337
+ "4 dance \n",
338
+ "6 rnb, soul, Alicia Keys, female vocalists, Karma \n",
339
+ "7 guitar girl, blues "
340
+ ],
341
+ "text/html": [
342
+ "\n",
343
+ " <div id=\"df-b9e5c35d-1534-4ad7-8661-887b39a472e9\" class=\"colab-df-container\">\n",
344
+ " <div>\n",
345
+ "<style scoped>\n",
346
+ " .dataframe tbody tr th:only-of-type {\n",
347
+ " vertical-align: middle;\n",
348
+ " }\n",
349
+ "\n",
350
+ " .dataframe tbody tr th {\n",
351
+ " vertical-align: top;\n",
352
+ " }\n",
353
+ "\n",
354
+ " .dataframe thead th {\n",
355
+ " text-align: right;\n",
356
+ " }\n",
357
+ "</style>\n",
358
+ "<table border=\"1\" class=\"dataframe\">\n",
359
+ " <thead>\n",
360
+ " <tr style=\"text-align: right;\">\n",
361
+ " <th></th>\n",
362
+ " <th>title</th>\n",
363
+ " <th>release</th>\n",
364
+ " <th>artist_name</th>\n",
365
+ " <th>duration</th>\n",
366
+ " <th>artist_familiarity</th>\n",
367
+ " <th>artist_hotttnesss</th>\n",
368
+ " <th>year</th>\n",
369
+ " <th>listeners</th>\n",
370
+ " <th>playcount</th>\n",
371
+ " <th>tags</th>\n",
372
+ " </tr>\n",
373
+ " </thead>\n",
374
+ " <tbody>\n",
375
+ " <tr>\n",
376
+ " <th>0</th>\n",
377
+ " <td>100 Club 1996 ''We Love You Beatles'' - Live</td>\n",
378
+ " <td>Sex Pistols - The Interviews</td>\n",
379
+ " <td>Sex Pistols</td>\n",
380
+ " <td>88.73751</td>\n",
381
+ " <td>0.731184</td>\n",
382
+ " <td>0.549204</td>\n",
383
+ " <td>0</td>\n",
384
+ " <td>172</td>\n",
385
+ " <td>210</td>\n",
386
+ " <td>The Beatles, title is a full sentence</td>\n",
387
+ " </tr>\n",
388
+ " <tr>\n",
389
+ " <th>1</th>\n",
390
+ " <td>Yo Quiero Contigo</td>\n",
391
+ " <td>Sentenciados - Platinum Edition</td>\n",
392
+ " <td>Baby Rasta &amp; Gringo</td>\n",
393
+ " <td>167.36608</td>\n",
394
+ " <td>0.610186</td>\n",
395
+ " <td>0.355320</td>\n",
396
+ " <td>0</td>\n",
397
+ " <td>9753</td>\n",
398
+ " <td>16911</td>\n",
399
+ " <td>Reggaeton, alexis y fido, Eliana, mis videos, ...</td>\n",
400
+ " </tr>\n",
401
+ " <tr>\n",
402
+ " <th>4</th>\n",
403
+ " <td>Emerald</td>\n",
404
+ " <td>Emerald</td>\n",
405
+ " <td>Bedrock</td>\n",
406
+ " <td>501.86404</td>\n",
407
+ " <td>0.654039</td>\n",
408
+ " <td>0.390625</td>\n",
409
+ " <td>2004</td>\n",
410
+ " <td>973</td>\n",
411
+ " <td>2247</td>\n",
412
+ " <td>dance</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <th>6</th>\n",
416
+ " <td>Karma</td>\n",
417
+ " <td>The Diary Of Alicia Keys</td>\n",
418
+ " <td>Alicia Keys</td>\n",
419
+ " <td>255.99955</td>\n",
420
+ " <td>0.933916</td>\n",
421
+ " <td>0.778674</td>\n",
422
+ " <td>2003</td>\n",
423
+ " <td>250304</td>\n",
424
+ " <td>1028356</td>\n",
425
+ " <td>rnb, soul, Alicia Keys, female vocalists, Karma</td>\n",
426
+ " </tr>\n",
427
+ " <tr>\n",
428
+ " <th>7</th>\n",
429
+ " <td>Money Blues</td>\n",
430
+ " <td>Slidetime</td>\n",
431
+ " <td>Joanna Connor</td>\n",
432
+ " <td>243.66975</td>\n",
433
+ " <td>0.479218</td>\n",
434
+ " <td>0.332857</td>\n",
435
+ " <td>0</td>\n",
436
+ " <td>429</td>\n",
437
+ " <td>1008</td>\n",
438
+ " <td>guitar girl, blues</td>\n",
439
+ " </tr>\n",
440
+ " </tbody>\n",
441
+ "</table>\n",
442
+ "</div>\n",
443
+ " <div class=\"colab-df-buttons\">\n",
444
+ "\n",
445
+ " <div class=\"colab-df-container\">\n",
446
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b9e5c35d-1534-4ad7-8661-887b39a472e9')\"\n",
447
+ " title=\"Convert this dataframe to an interactive table.\"\n",
448
+ " style=\"display:none;\">\n",
449
+ "\n",
450
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
451
+ " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
452
+ " </svg>\n",
453
+ " </button>\n",
454
+ "\n",
455
+ " <style>\n",
456
+ " .colab-df-container {\n",
457
+ " display:flex;\n",
458
+ " gap: 12px;\n",
459
+ " }\n",
460
+ "\n",
461
+ " .colab-df-convert {\n",
462
+ " background-color: #E8F0FE;\n",
463
+ " border: none;\n",
464
+ " border-radius: 50%;\n",
465
+ " cursor: pointer;\n",
466
+ " display: none;\n",
467
+ " fill: #1967D2;\n",
468
+ " height: 32px;\n",
469
+ " padding: 0 0 0 0;\n",
470
+ " width: 32px;\n",
471
+ " }\n",
472
+ "\n",
473
+ " .colab-df-convert:hover {\n",
474
+ " background-color: #E2EBFA;\n",
475
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
476
+ " fill: #174EA6;\n",
477
+ " }\n",
478
+ "\n",
479
+ " .colab-df-buttons div {\n",
480
+ " margin-bottom: 4px;\n",
481
+ " }\n",
482
+ "\n",
483
+ " [theme=dark] .colab-df-convert {\n",
484
+ " background-color: #3B4455;\n",
485
+ " fill: #D2E3FC;\n",
486
+ " }\n",
487
+ "\n",
488
+ " [theme=dark] .colab-df-convert:hover {\n",
489
+ " background-color: #434B5C;\n",
490
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
491
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
492
+ " fill: #FFFFFF;\n",
493
+ " }\n",
494
+ " </style>\n",
495
+ "\n",
496
+ " <script>\n",
497
+ " const buttonEl =\n",
498
+ " document.querySelector('#df-b9e5c35d-1534-4ad7-8661-887b39a472e9 button.colab-df-convert');\n",
499
+ " buttonEl.style.display =\n",
500
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
501
+ "\n",
502
+ " async function convertToInteractive(key) {\n",
503
+ " const element = document.querySelector('#df-b9e5c35d-1534-4ad7-8661-887b39a472e9');\n",
504
+ " const dataTable =\n",
505
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
506
+ " [key], {});\n",
507
+ " if (!dataTable) return;\n",
508
+ "\n",
509
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
510
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
511
+ " + ' to learn more about interactive tables.';\n",
512
+ " element.innerHTML = '';\n",
513
+ " dataTable['output_type'] = 'display_data';\n",
514
+ " await google.colab.output.renderOutput(dataTable, element);\n",
515
+ " const docLink = document.createElement('div');\n",
516
+ " docLink.innerHTML = docLinkHtml;\n",
517
+ " element.appendChild(docLink);\n",
518
+ " }\n",
519
+ " </script>\n",
520
+ " </div>\n",
521
+ "\n",
522
+ "\n",
523
+ "<div id=\"df-3ffda883-e826-470a-8413-bc736b2d9130\">\n",
524
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-3ffda883-e826-470a-8413-bc736b2d9130')\"\n",
525
+ " title=\"Suggest charts\"\n",
526
+ " style=\"display:none;\">\n",
527
+ "\n",
528
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
529
+ " width=\"24px\">\n",
530
+ " <g>\n",
531
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
532
+ " </g>\n",
533
+ "</svg>\n",
534
+ " </button>\n",
535
+ "\n",
536
+ "<style>\n",
537
+ " .colab-df-quickchart {\n",
538
+ " --bg-color: #E8F0FE;\n",
539
+ " --fill-color: #1967D2;\n",
540
+ " --hover-bg-color: #E2EBFA;\n",
541
+ " --hover-fill-color: #174EA6;\n",
542
+ " --disabled-fill-color: #AAA;\n",
543
+ " --disabled-bg-color: #DDD;\n",
544
+ " }\n",
545
+ "\n",
546
+ " [theme=dark] .colab-df-quickchart {\n",
547
+ " --bg-color: #3B4455;\n",
548
+ " --fill-color: #D2E3FC;\n",
549
+ " --hover-bg-color: #434B5C;\n",
550
+ " --hover-fill-color: #FFFFFF;\n",
551
+ " --disabled-bg-color: #3B4455;\n",
552
+ " --disabled-fill-color: #666;\n",
553
+ " }\n",
554
+ "\n",
555
+ " .colab-df-quickchart {\n",
556
+ " background-color: var(--bg-color);\n",
557
+ " border: none;\n",
558
+ " border-radius: 50%;\n",
559
+ " cursor: pointer;\n",
560
+ " display: none;\n",
561
+ " fill: var(--fill-color);\n",
562
+ " height: 32px;\n",
563
+ " padding: 0;\n",
564
+ " width: 32px;\n",
565
+ " }\n",
566
+ "\n",
567
+ " .colab-df-quickchart:hover {\n",
568
+ " background-color: var(--hover-bg-color);\n",
569
+ " box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
570
+ " fill: var(--button-hover-fill-color);\n",
571
+ " }\n",
572
+ "\n",
573
+ " .colab-df-quickchart-complete:disabled,\n",
574
+ " .colab-df-quickchart-complete:disabled:hover {\n",
575
+ " background-color: var(--disabled-bg-color);\n",
576
+ " fill: var(--disabled-fill-color);\n",
577
+ " box-shadow: none;\n",
578
+ " }\n",
579
+ "\n",
580
+ " .colab-df-spinner {\n",
581
+ " border: 2px solid var(--fill-color);\n",
582
+ " border-color: transparent;\n",
583
+ " border-bottom-color: var(--fill-color);\n",
584
+ " animation:\n",
585
+ " spin 1s steps(1) infinite;\n",
586
+ " }\n",
587
+ "\n",
588
+ " @keyframes spin {\n",
589
+ " 0% {\n",
590
+ " border-color: transparent;\n",
591
+ " border-bottom-color: var(--fill-color);\n",
592
+ " border-left-color: var(--fill-color);\n",
593
+ " }\n",
594
+ " 20% {\n",
595
+ " border-color: transparent;\n",
596
+ " border-left-color: var(--fill-color);\n",
597
+ " border-top-color: var(--fill-color);\n",
598
+ " }\n",
599
+ " 30% {\n",
600
+ " border-color: transparent;\n",
601
+ " border-left-color: var(--fill-color);\n",
602
+ " border-top-color: var(--fill-color);\n",
603
+ " border-right-color: var(--fill-color);\n",
604
+ " }\n",
605
+ " 40% {\n",
606
+ " border-color: transparent;\n",
607
+ " border-right-color: var(--fill-color);\n",
608
+ " border-top-color: var(--fill-color);\n",
609
+ " }\n",
610
+ " 60% {\n",
611
+ " border-color: transparent;\n",
612
+ " border-right-color: var(--fill-color);\n",
613
+ " }\n",
614
+ " 80% {\n",
615
+ " border-color: transparent;\n",
616
+ " border-right-color: var(--fill-color);\n",
617
+ " border-bottom-color: var(--fill-color);\n",
618
+ " }\n",
619
+ " 90% {\n",
620
+ " border-color: transparent;\n",
621
+ " border-bottom-color: var(--fill-color);\n",
622
+ " }\n",
623
+ " }\n",
624
+ "</style>\n",
625
+ "\n",
626
+ " <script>\n",
627
+ " async function quickchart(key) {\n",
628
+ " const quickchartButtonEl =\n",
629
+ " document.querySelector('#' + key + ' button');\n",
630
+ " quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
631
+ " quickchartButtonEl.classList.add('colab-df-spinner');\n",
632
+ " try {\n",
633
+ " const charts = await google.colab.kernel.invokeFunction(\n",
634
+ " 'suggestCharts', [key], {});\n",
635
+ " } catch (error) {\n",
636
+ " console.error('Error during call to suggestCharts:', error);\n",
637
+ " }\n",
638
+ " quickchartButtonEl.classList.remove('colab-df-spinner');\n",
639
+ " quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
640
+ " }\n",
641
+ " (() => {\n",
642
+ " let quickchartButtonEl =\n",
643
+ " document.querySelector('#df-3ffda883-e826-470a-8413-bc736b2d9130 button');\n",
644
+ " quickchartButtonEl.style.display =\n",
645
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
646
+ " })();\n",
647
+ " </script>\n",
648
+ "</div>\n",
649
+ "\n",
650
+ " </div>\n",
651
+ " </div>\n"
652
+ ],
653
+ "application/vnd.google.colaboratory.intrinsic+json": {
654
+ "type": "dataframe",
655
+ "variable_name": "df",
656
+ "summary": "{\n \"name\": \"df\",\n \"rows\": 5063,\n \"fields\": [\n {\n \"column\": \"title\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 4854,\n \"samples\": [\n \"I Wish I Had A Girl\",\n \"Jump [Jacques Lu Cont Edit]\",\n \"Mulin' Around\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"release\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 4187,\n \"samples\": [\n \"Le Bordel Magnifique\",\n \"Charlotte's Web (OST)\",\n \"X.O. Experience\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"artist_name\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2461,\n \"samples\": [\n \"Lee Ritenour\",\n \"Pennywise\",\n \"Anneli Drecker\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"duration\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 107.73289375974717,\n \"min\": 1.04444,\n \"max\": 1815.2224,\n \"num_unique_values\": 3939,\n \"samples\": [\n 294.24281,\n 240.79628,\n 115.53914\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"artist_familiarity\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.14886096792686204,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2474,\n \"samples\": [\n 0.787098355481,\n 0.481771820142,\n 0.374024633035\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"artist_hotttnesss\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.1347303774485448,\n \"min\": 0.0,\n \"max\": 1.08250255673,\n \"num_unique_values\": 2398,\n \"samples\": [\n 0.376018761952,\n 0.355667956383,\n 0.289970666912\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"year\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 917,\n \"min\": 0,\n \"max\": 2010,\n \"num_unique_values\": 69,\n \"samples\": [\n 1979,\n 0,\n 1965\n ],\n 
\"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"listeners\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 150513,\n \"min\": 0,\n \"max\": 2451482,\n \"num_unique_values\": 3914,\n \"samples\": [\n 781546,\n 6216,\n 396579\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"playcount\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1115103,\n \"min\": 0,\n \"max\": 23182516,\n \"num_unique_values\": 4422,\n \"samples\": [\n 62736,\n 1305,\n 17033\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"tags\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 4583,\n \"samples\": [\n \"dance, 90s, trance, House, jungle\",\n \"country, favorite songs, classic country, linedance, Martina McBride\",\n \"90s, heavy metal, thrash metal, metal, punk\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
657
+ }
658
+ },
659
+ "metadata": {},
660
+ "execution_count": 2
661
+ }
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "source": [
667
+ "# Display basic information about the dataset\n",
668
+ "print(df.info())\n",
669
+ "\n",
670
+ "# Display summary statistics for numerical columns\n",
671
+ "print(df.describe())\n",
672
+ "\n",
673
+ "# Display the number of unique values in each categorical column\n",
674
+ "print(\"Unique values in 'title':\", df['title'].nunique())\n",
675
+ "print(\"Unique values in 'artist_name':\", df['artist_name'].nunique())\n",
676
+ "print(\"Unique values in 'tags':\", df['tags'].nunique())"
677
+ ],
678
+ "metadata": {
679
+ "colab": {
680
+ "base_uri": "https://localhost:8080/"
681
+ },
682
+ "id": "b_sSacbdHcn6",
683
+ "outputId": "f745b028-fd97-4b19-b9f0-9e041621e5d3"
684
+ },
685
+ "execution_count": null,
686
+ "outputs": [
687
+ {
688
+ "output_type": "stream",
689
+ "name": "stdout",
690
+ "text": [
691
+ "<class 'pandas.core.frame.DataFrame'>\n",
692
+ "Index: 5063 entries, 0 to 9530\n",
693
+ "Data columns (total 10 columns):\n",
694
+ " # Column Non-Null Count Dtype \n",
695
+ "--- ------ -------------- ----- \n",
696
+ " 0 title 5063 non-null object \n",
697
+ " 1 release 5063 non-null object \n",
698
+ " 2 artist_name 5063 non-null object \n",
699
+ " 3 duration 5063 non-null float64\n",
700
+ " 4 artist_familiarity 5063 non-null float64\n",
701
+ " 5 artist_hotttnesss 5063 non-null float64\n",
702
+ " 6 year 5063 non-null int64 \n",
703
+ " 7 listeners 5063 non-null int64 \n",
704
+ " 8 playcount 5063 non-null int64 \n",
705
+ " 9 tags 5063 non-null object \n",
706
+ "dtypes: float64(3), int64(3), object(4)\n",
707
+ "memory usage: 435.1+ KB\n",
708
+ "None\n",
709
+ " duration artist_familiarity artist_hotttnesss year \\\n",
710
+ "count 5063.000000 5063.000000 5063.000000 5063.000000 \n",
711
+ "mean 243.156073 0.626861 0.439664 1392.483705 \n",
712
+ "std 107.732894 0.148861 0.134730 917.360336 \n",
713
+ "min 1.044440 0.000000 0.000000 0.000000 \n",
714
+ "25% 183.535870 0.527033 0.363132 0.000000 \n",
715
+ "50% 229.145670 0.619531 0.417819 1993.000000 \n",
716
+ "75% 280.920365 0.731184 0.510325 2004.000000 \n",
717
+ "max 1815.222400 1.000000 1.082503 2010.000000 \n",
718
+ "\n",
719
+ " listeners playcount \n",
720
+ "count 5.063000e+03 5.063000e+03 \n",
721
+ "mean 4.526352e+04 2.622274e+05 \n",
722
+ "std 1.505135e+05 1.115104e+06 \n",
723
+ "min 0.000000e+00 0.000000e+00 \n",
724
+ "25% 7.545000e+02 1.894500e+03 \n",
725
+ "50% 3.387000e+03 9.439000e+03 \n",
726
+ "75% 1.787350e+04 6.269500e+04 \n",
727
+ "max 2.451482e+06 2.318252e+07 \n",
728
+ "Unique values in 'title': 4854\n",
729
+ "Unique values in 'artist_name': 2461\n",
730
+ "Unique values in 'tags': 4583\n"
731
+ ]
732
+ }
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "markdown",
737
+ "source": [
738
+ "# **Preprocessing**"
739
+ ],
740
+ "metadata": {
741
+ "id": "wPVFDtk9g9ox"
742
+ }
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "source": [
747
+ "import pandas as pd\n",
748
+ "from sklearn.preprocessing import LabelEncoder, MinMaxScaler\n",
749
+ "import joblib\n",
750
+ "import re\n",
751
+ "\n",
752
+ "# Function to clean tags and artist names\n",
753
+ "def clean_text(text):\n",
754
+ " # Convert to lowercase\n",
755
+ " text = text.lower()\n",
756
+ " # Remove special characters and digits\n",
757
+ " text = re.sub(r'[^a-zA-Z\\s]', '', text)\n",
758
+ " # Remove extra white spaces\n",
759
+ " text = re.sub(r'\\s+', ' ', text).strip()\n",
760
+ " return text\n",
761
+ "\n",
762
+ "# Clean 'tags' and 'artist_name' columns\n",
763
+ "df['tags'] = df['tags'].apply(clean_text)\n",
764
+ "df['artist_name'] = df['artist_name'].apply(clean_text)\n",
765
+ "\n",
766
+ "def label_encode_data(df):\n",
767
+ " df = df.copy(deep=True)\n",
768
+ " label_encoders = {}\n",
769
+ " unknown_label = 'unknown' # Extra class added to every encoder so values unseen at fit time have a fallback\n",
770
+ "\n",
771
+ " for column in ['tags', 'title', 'artist_name']:\n",
772
+ " le = LabelEncoder()\n",
773
+ " unique_categories = df[column].unique().tolist()\n",
774
+ " unique_categories.append(unknown_label)\n",
775
+ " le.fit(unique_categories)\n",
776
+ " df[column] = le.transform(df[column].astype(str))\n",
777
+ " label_encoders[column] = le\n",
778
+ "\n",
779
+ " return df, label_encoders\n",
780
+ "\n",
781
+ "# Min-max scale 'listeners' and 'playcount' in place (these columns are not among the features selected for X later)\n",
782
+ "scaler = MinMaxScaler()\n",
783
+ "df[['listeners', 'playcount']] = scaler.fit_transform(df[['listeners', 'playcount']])\n",
784
+ "\n",
785
+ "# Label encode categorical features\n",
786
+ "df_scaled, label_encoders = label_encode_data(df)\n",
787
+ "\n",
788
+ "# Save the encoders and scaler\n",
789
+ "joblib.dump(label_encoders, \"/content/new_label_encoders.joblib\")\n",
790
+ "joblib.dump(scaler, \"/content/new_scaler.joblib\")\n",
791
+ "\n",
792
+ "print(\"Label encoders and scaler saved successfully.\")\n"
793
+ ],
794
+ "metadata": {
795
+ "colab": {
796
+ "base_uri": "https://localhost:8080/"
797
+ },
798
+ "id": "3fsU1IvylyZg",
799
+ "outputId": "c2ba3adc-c077-454a-94de-ca9bb0ba4807"
800
+ },
801
+ "execution_count": null,
802
+ "outputs": [
803
+ {
804
+ "output_type": "stream",
805
+ "name": "stdout",
806
+ "text": [
807
+ "Label encoders and scaler saved successfully.\n"
808
+ ]
809
+ }
810
+ ]
811
+ },
812
+ {
813
+ "cell_type": "code",
814
+ "source": [
815
+ "from sklearn.model_selection import train_test_split\n",
816
+ "\n",
817
+ "# Split data into features and target\n",
818
+ "X = df_scaled[['tags', 'artist_name']]\n",
819
+ "y = df_scaled['title']\n",
820
+ "\n",
821
+ "# Split the dataset into training and testing sets\n",
822
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
823
+ "print(\"Data split into training and testing sets.\")\n",
824
+ "\n",
825
+ "# Number of unique titles\n",
826
+ "num_unique_titles = len(label_encoders['title'].classes_)\n",
827
+ "\n",
828
+ "# Check for out-of-bounds indices in y_train and y_test\n",
829
+ "print(\"Maximum value in y_train:\", y_train.max())\n",
830
+ "print(\"Maximum value in y_test:\", y_test.max())\n",
831
+ "print(\"Number of unique titles:\", num_unique_titles)\n",
832
+ "\n",
833
+ "# If any out-of-bounds values are found, print them\n",
834
+ "out_of_bounds_train = y_train[y_train >= num_unique_titles]\n",
835
+ "out_of_bounds_test = y_test[y_test >= num_unique_titles]\n",
836
+ "\n",
837
+ "if not out_of_bounds_train.empty:\n",
838
+ " print(\"Out-of-bounds values in y_train:\", out_of_bounds_train)\n",
839
+ "if not out_of_bounds_test.empty:\n",
840
+ " print(\"Out-of-bounds values in y_test:\", out_of_bounds_test)\n",
841
+ "\n",
842
+ "# Clamp any label >= num_unique_titles to the last valid class index (no labels were out of range in this run, so this is a no-op)\n",
843
+ "y_train = y_train.clip(upper=num_unique_titles - 1)\n",
844
+ "y_test = y_test.clip(upper=num_unique_titles - 1)\n",
845
+ "\n",
846
+ "# Print the maximum values after clipping\n",
847
+ "print(\"Maximum value in y_train after clipping:\", y_train.max())\n",
848
+ "print(\"Maximum value in y_test after clipping:\", y_test.max())\n"
849
+ ],
850
+ "metadata": {
851
+ "colab": {
852
+ "base_uri": "https://localhost:8080/"
853
+ },
854
+ "id": "JBWZWp_8Jr82",
855
+ "outputId": "73a312c1-3615-4a87-965b-c2fc41fc50e7"
856
+ },
857
+ "execution_count": null,
858
+ "outputs": [
859
+ {
860
+ "output_type": "stream",
861
+ "name": "stdout",
862
+ "text": [
863
+ "Data split into training and testing sets.\n",
864
+ "Maximum value in y_train: 4854\n",
865
+ "Maximum value in y_test: 4850\n",
866
+ "Number of unique titles: 4855\n",
867
+ "Maximum value in y_train after clipping: 4854\n",
868
+ "Maximum value in y_test after clipping: 4850\n"
869
+ ]
870
+ }
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "markdown",
875
+ "source": [
876
+ "# **Training**"
877
+ ],
878
+ "metadata": {
879
+ "id": "syYhdUbxgA-K"
880
+ }
881
+ },
882
+ {
883
+ "cell_type": "code",
884
+ "source": [
885
+ "import torch\n",
886
+ "import torch.nn as nn\n",
887
+ "import torch.optim as optim\n",
888
+ "from torch.utils.data import DataLoader\n",
889
+ "import numpy as np\n",
890
+ "\n",
891
+ "# Define the neural network model with Dropout and Batch Normalization\n",
892
+ "class ImprovedSongRecommender(nn.Module):\n",
893
+ " def __init__(self, input_size, num_titles):\n",
894
+ " super(ImprovedSongRecommender, self).__init__()\n",
895
+ " self.fc1 = nn.Linear(input_size, 128)\n",
896
+ " self.bn1 = nn.BatchNorm1d(128)\n",
897
+ " self.fc2 = nn.Linear(128, 256)\n",
898
+ " self.bn2 = nn.BatchNorm1d(256)\n",
899
+ " self.fc3 = nn.Linear(256, 128)\n",
900
+ " self.bn3 = nn.BatchNorm1d(128)\n",
901
+ " self.output = nn.Linear(128, num_titles)\n",
902
+ " self.dropout = nn.Dropout(0.5)\n",
903
+ "\n",
904
+ " def forward(self, x):\n",
905
+ " x = torch.relu(self.bn1(self.fc1(x)))\n",
906
+ " x = self.dropout(x)\n",
907
+ " x = torch.relu(self.bn2(self.fc2(x)))\n",
908
+ " x = self.dropout(x)\n",
909
+ " x = torch.relu(self.bn3(self.fc3(x)))\n",
910
+ " x = self.dropout(x)\n",
911
+ " x = self.output(x)\n",
912
+ " return x\n",
913
+ "\n",
914
+ "# Adjusting input size for the model\n",
915
+ "input_size = X_train.shape[1] # Number of features in the input\n",
916
+ "num_unique_titles = len(label_encoders['title'].classes_) # Number of unique titles including 'unknown'\n",
917
+ "\n",
918
+ "# Initialize the model with the correct input size and output size\n",
919
+ "model = ImprovedSongRecommender(input_size, num_unique_titles)\n",
920
+ "\n",
921
+ "# Initialize the optimizer and loss function\n",
922
+ "optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)\n",
923
+ "criterion = nn.CrossEntropyLoss()\n",
924
+ "\n",
925
+ "# Use a learning rate scheduler\n",
926
+ "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)\n",
927
+ "\n",
928
+ "# Early stopping parameters\n",
929
+ "patience = 3\n",
930
+ "min_delta = 0.01\n",
931
+ "best_val_loss = np.inf\n",
932
+ "patience_counter = 0\n",
933
+ "\n",
934
+ "# Function to train the model\n",
935
+ "def train_model(model, X_train, y_train, X_test, y_test):\n",
936
+ " global best_val_loss, patience_counter\n",
937
+ " train_loader = DataLoader(list(zip(X_train.values.astype(float), y_train)), batch_size=10, shuffle=True)\n",
938
+ " test_loader = DataLoader(list(zip(X_test.values.astype(float), y_test)), batch_size=10, shuffle=False)\n",
939
+ "\n",
940
+ " model.train()\n",
941
+ " for epoch in range(20): # At most 20 epochs; the early-stopping check below can end training sooner\n",
942
+ " train_loss = 0\n",
943
+ " for features, labels in train_loader:\n",
944
+ " optimizer.zero_grad()\n",
945
+ " outputs = model(features.float())\n",
946
+ " loss = criterion(outputs, labels.long())\n",
947
+ " loss.backward()\n",
948
+ " optimizer.step()\n",
949
+ " train_loss += loss.item()\n",
950
+ "\n",
951
+ " # Step the scheduler\n",
952
+ " scheduler.step()\n",
953
+ "\n",
954
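+ " # Validation phase. NOTE: nothing switches back to model.train() afterwards (it is only called once, before the epoch loop), so from epoch 2 on training runs with dropout disabled and BatchNorm in inference mode\n",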
955
+ " model.eval()\n",
956
+ " validation_loss = 0\n",
957
+ " with torch.no_grad():\n",
958
+ " for features, labels in test_loader:\n",
959
+ " outputs = model(features.float())\n",
960
+ " loss = criterion(outputs, labels.long())\n",
961
+ " validation_loss += loss.item()\n",
962
+ "\n",
963
+ " avg_val_loss = validation_loss / len(test_loader)\n",
964
+ " print(f'Epoch {epoch+1}, Training Loss: {train_loss / len(train_loader)}, Validation Loss: {avg_val_loss}')\n",
965
+ "\n",
966
+ " # Early stopping\n",
967
+ " if avg_val_loss < best_val_loss - min_delta:\n",
968
+ " best_val_loss = avg_val_loss\n",
969
+ " patience_counter = 0\n",
970
+ " else:\n",
971
+ " patience_counter += 1\n",
972
+ " if patience_counter >= patience:\n",
973
+ " print(\"Early stopping triggered\")\n",
974
+ " break\n",
975
+ "\n",
976
+ "# Train the model\n",
977
+ "train_model(model, X_train, y_train, X_test, y_test)\n",
978
+ "\n",
979
+ "# Save the trained model\n",
980
+ "model_path = '/content/improved_model.pth'\n",
981
+ "torch.save(model.state_dict(), model_path)\n",
982
+ "\n",
983
+ "print(\"Improved model trained and saved successfully.\")\n"
984
+ ],
985
+ "metadata": {
986
+ "colab": {
987
+ "base_uri": "https://localhost:8080/"
988
+ },
989
+ "id": "aaR1IGymKQq2",
990
+ "outputId": "9e5115a5-1a75-4672-a0b3-4fdd314e1a79"
991
+ },
992
+ "execution_count": null,
993
+ "outputs": [
994
+ {
995
+ "output_type": "stream",
996
+ "name": "stdout",
997
+ "text": [
998
+ "Epoch 1, Training Loss: 8.921830113728841, Validation Loss: 8.836441385979747\n",
999
+ "Epoch 2, Training Loss: 8.331391870239635, Validation Loss: 9.148561271966672\n",
1000
+ "Epoch 3, Training Loss: 7.494005516429007, Validation Loss: 10.484928570541681\n",
1001
+ "Epoch 4, Training Loss: 6.704833826606657, Validation Loss: 11.745069999320835\n",
1002
+ "Early stopping triggered\n",
1003
+ "Improved model trained and saved successfully.\n"
1004
+ ]
1005
+ }
1006
+ ]
1007
+ },
1008
+ {
1009
+ "cell_type": "markdown",
1010
+ "source": [
1011
+ "# **Testing**"
1012
+ ],
1013
+ "metadata": {
1014
+ "id": "g4hJVlNXf5Vu"
1015
+ }
1016
+ },
1017
+ {
1018
+ "cell_type": "code",
1019
+ "source": [
1020
+ "import torch\n",
1021
+ "from joblib import load\n",
1022
+ "\n",
1023
+ "# Re-declare the training architecture; layer names and shapes must match the saved state_dict\n",
1024
+ "class ImprovedSongRecommender(nn.Module):\n",
1025
+ " def __init__(self, input_size, num_titles):\n",
1026
+ " super(ImprovedSongRecommender, self).__init__()\n",
1027
+ " self.fc1 = nn.Linear(input_size, 128)\n",
1028
+ " self.bn1 = nn.BatchNorm1d(128)\n",
1029
+ " self.fc2 = nn.Linear(128, 256)\n",
1030
+ " self.bn2 = nn.BatchNorm1d(256)\n",
1031
+ " self.fc3 = nn.Linear(256, 128)\n",
1032
+ " self.bn3 = nn.BatchNorm1d(128)\n",
1033
+ " self.output = nn.Linear(128, num_titles)\n",
1034
+ " self.dropout = nn.Dropout(0.5)\n",
1035
+ "\n",
1036
+ " def forward(self, x):\n",
1037
+ " x = torch.relu(self.bn1(self.fc1(x)))\n",
1038
+ " x = self.dropout(x)\n",
1039
+ " x = torch.relu(self.bn2(self.fc2(x)))\n",
1040
+ " x = self.dropout(x)\n",
1041
+ " x = torch.relu(self.bn3(self.fc3(x)))\n",
1042
+ " x = self.dropout(x)\n",
1043
+ " x = self.output(x)\n",
1044
+ " return x\n",
1045
+ "\n",
1046
+ "# Load the trained model\n",
1047
+ "model_path = '/content/improved_model.pth'\n",
1048
+ "num_unique_titles = 4855 # Must equal len(label_encoders['title'].classes_) from training (4855 in this run), or the saved output layer will not load\n",
1049
+ "\n",
1050
+ "model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles) # input_size=2: the two encoded features 'tags' and 'artist_name'\n",
1051
+ "model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
1052
+ "model.eval()\n",
1053
+ "\n",
1054
+ "# Load the label encoders and scaler\n",
1055
+ "label_encoders_path = '/content/new_label_encoders.joblib'\n",
1056
+ "scaler_path = '/content/new_scaler.joblib'\n",
1057
+ "\n",
1058
+ "label_encoders = load(label_encoders_path)\n",
1059
+ "scaler = load(scaler_path)\n",
1060
+ "\n",
1061
+ "# Create a mapping from encoded indices to actual song titles\n",
1062
+ "index_to_song_title = {index: title for index, title in enumerate(label_encoders['title'].classes_)}\n",
1063
+ "\n",
1064
+ "def encode_input(tags, artist_name):\n",
1065
+ " tags = tags.strip().replace('\\n', '')\n",
1066
+ " artist_name = artist_name.strip().replace('\\n', '')\n",
1067
+ "\n",
1068
+ " try:\n",
1069
+ " encoded_tags = label_encoders['tags'].transform([tags])[0]\n",
1070
+ " except ValueError:\n",
1071
+ " encoded_tags = label_encoders['tags'].transform(['unknown'])[0]\n",
1072
+ "\n",
1073
+ " try:\n",
1074
+ " encoded_artist = label_encoders['artist_name'].transform([artist_name])[0]\n",
1075
+ " except ValueError:\n",
1076
+ " encoded_artist = label_encoders['artist_name'].transform(['unknown'])[0]\n",
1077
+ "\n",
1078
+ " return [encoded_tags, encoded_artist]\n",
1079
+ "\n",
1080
+ "def recommend_songs(tags, artist_name):\n",
1081
+ " encoded_input = encode_input(tags, artist_name)\n",
1082
+ " input_tensor = torch.tensor([encoded_input]).float()\n",
1083
+ "\n",
1084
+ " with torch.no_grad():\n",
1085
+ " output = model(input_tensor)\n",
1086
+ "\n",
1087
+ " recommendations_indices = torch.topk(output, 5).indices.squeeze().tolist()\n",
1088
+ " recommendations = [index_to_song_title.get(idx, \"Unknown song\") for idx in recommendations_indices]\n",
1089
+ "\n",
1090
+ " return recommendations\n",
1091
+ "\n",
1092
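+ "# Test the recommendation function. NOTE: encode_input() does not apply clean_text(), so a mixed-case name like 'The Beatles' cannot match the lowercased encoder classes and is encoded as 'unknown'\n",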
1093
+ "tags = \"rock\"\n",
1094
+ "artist_name = \"The Beatles\"\n",
1095
+ "\n",
1096
+ "recommendations = recommend_songs(tags, artist_name)\n",
1097
+ "print(\"Recommendations:\", recommendations)\n"
1098
+ ],
1099
+ "metadata": {
1100
+ "colab": {
1101
+ "base_uri": "https://localhost:8080/"
1102
+ },
1103
+ "id": "KwqV-HnCOvtz",
1104
+ "outputId": "d412ce92-3ab8-4f3d-df83-22ef9e857203"
1105
+ },
1106
+ "execution_count": null,
1107
+ "outputs": [
1108
+ {
1109
+ "output_type": "stream",
1110
+ "name": "stdout",
1111
+ "text": [
1112
+ "Recommendations: ['Betrayal Is A Symptom', 'The Earth Will Shake', 'Saturday', 'Firehouse Rock', 'Breathe Easy']\n"
1113
+ ]
1114
+ }
1115
+ ]
1116
+ },
1117
+ {
1118
+ "cell_type": "code",
1119
+ "source": [
1120
+ "import torch\n",
1121
+ "from joblib import load\n",
1122
+ "\n",
1123
+ "# Re-declare the training architecture; layer names and shapes must match the saved state_dict\n",
1124
+ "class ImprovedSongRecommender(nn.Module):\n",
1125
+ " def __init__(self, input_size, num_titles):\n",
1126
+ " super(ImprovedSongRecommender, self).__init__()\n",
1127
+ " self.fc1 = nn.Linear(input_size, 128)\n",
1128
+ " self.bn1 = nn.BatchNorm1d(128)\n",
1129
+ " self.fc2 = nn.Linear(128, 256)\n",
1130
+ " self.bn2 = nn.BatchNorm1d(256)\n",
1131
+ " self.fc3 = nn.Linear(256, 128)\n",
1132
+ " self.bn3 = nn.BatchNorm1d(128)\n",
1133
+ " self.output = nn.Linear(128, num_titles)\n",
1134
+ " self.dropout = nn.Dropout(0.5)\n",
1135
+ "\n",
1136
+ " def forward(self, x):\n",
1137
+ " x = torch.relu(self.bn1(self.fc1(x)))\n",
1138
+ " x = self.dropout(x)\n",
1139
+ " x = torch.relu(self.bn2(self.fc2(x)))\n",
1140
+ " x = self.dropout(x)\n",
1141
+ " x = torch.relu(self.bn3(self.fc3(x)))\n",
1142
+ " x = self.dropout(x)\n",
1143
+ " x = self.output(x)\n",
1144
+ " return x\n",
1145
+ "\n",
1146
+ "# Load the trained model\n",
1147
+ "model_path = '/content/improved_model.pth'\n",
1148
+ "num_unique_titles = 4855 # Must equal len(label_encoders['title'].classes_) from training (4855 in this run), or the saved output layer will not load\n",
1149
+ "\n",
1150
+ "model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles) # input_size=2: the two encoded features 'tags' and 'artist_name'\n",
1151
+ "model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
1152
+ "model.eval()\n",
1153
+ "\n",
1154
+ "# Load the label encoders and scaler\n",
1155
+ "label_encoders_path = '/content/new_label_encoders.joblib'\n",
1156
+ "scaler_path = '/content/new_scaler.joblib'\n",
1157
+ "\n",
1158
+ "label_encoders = load(label_encoders_path)\n",
1159
+ "scaler = load(scaler_path)\n",
1160
+ "\n",
1161
+ "# Create a mapping from encoded indices to actual song titles\n",
1162
+ "index_to_song_title = {index: title for index, title in enumerate(label_encoders['title'].classes_)}\n",
1163
+ "\n",
1164
+ "def encode_input(tags, artist_name):\n",
1165
+ " tags = tags.strip().replace('\\n', '')\n",
1166
+ " artist_name = artist_name.strip().replace('\\n', '')\n",
1167
+ "\n",
1168
+ " try:\n",
1169
+ " encoded_tags = label_encoders['tags'].transform([tags])[0]\n",
1170
+ " except ValueError:\n",
1171
+ " encoded_tags = label_encoders['tags'].transform(['unknown'])[0]\n",
1172
+ "\n",
1173
+ " try:\n",
1174
+ " encoded_artist = label_encoders['artist_name'].transform([artist_name])[0]\n",
1175
+ " except ValueError:\n",
1176
+ " encoded_artist = label_encoders['artist_name'].transform(['unknown'])[0]\n",
1177
+ "\n",
1178
+ " return [encoded_tags, encoded_artist]\n",
1179
+ "\n",
1180
+ "def recommend_songs(tags, artist_name):\n",
1181
+ " encoded_input = encode_input(tags, artist_name)\n",
1182
+ " input_tensor = torch.tensor([encoded_input]).float()\n",
1183
+ "\n",
1184
+ " with torch.no_grad():\n",
1185
+ " output = model(input_tensor)\n",
1186
+ "\n",
1187
+ " recommendations_indices = torch.topk(output, 5).indices.squeeze().tolist()\n",
1188
+ " recommendations = [index_to_song_title.get(idx, \"Unknown song\") for idx in recommendations_indices]\n",
1189
+ "\n",
1190
+ " return recommendations\n",
1191
+ "\n",
1192
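+ "# Test the recommendation function with new inputs. NOTE: encode_input() does not apply clean_text(), so mixed-case names like 'Adele' or 'Miles Davis' cannot match the lowercased encoder classes and are encoded as 'unknown'\n",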
1193
+ "tags = \"pop\"\n",
1194
+ "artist_name = \"Adele\"\n",
1195
+ "\n",
1196
+ "recommendations = recommend_songs(tags, artist_name)\n",
1197
+ "print(\"Recommendations:\", recommendations)\n",
1198
+ "\n",
1199
+ "# Test with another set of inputs\n",
1200
+ "tags = \"jazz\"\n",
1201
+ "artist_name = \"Miles Davis\"\n",
1202
+ "\n",
1203
+ "recommendations = recommend_songs(tags, artist_name)\n",
1204
+ "print(\"Recommendations:\", recommendations)\n"
1205
+ ],
1206
+ "metadata": {
1207
+ "colab": {
1208
+ "base_uri": "https://localhost:8080/"
1209
+ },
1210
+ "id": "3HzLKv5mPxOv",
1211
+ "outputId": "62b37d04-4857-44fb-b5c4-8ead55db9b1a"
1212
+ },
1213
+ "execution_count": null,
1214
+ "outputs": [
1215
+ {
1216
+ "output_type": "stream",
1217
+ "name": "stdout",
1218
+ "text": [
1219
+ "Recommendations: ['Betrayal Is A Symptom', 'Carnival (from \"Black Orpheus\")', 'Saturday', 'The Earth Will Shake', 'Start!']\n",
1220
+ "Recommendations: ['Old Friends', 'Betrayal Is A Symptom', 'Between Love & Hate', 'Carnival (from \"Black Orpheus\")', 'Satin Doll']\n"
1221
+ ]
1222
+ }
1223
+ ]
1224
+ }
1225
+ ]
1226
+ }