File size: 4,522 Bytes
9dd7d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Evaluating the Recommendation Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from joblib import load\n",
    "import sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the same neural network model\n",
    "class ImprovedSongRecommender(nn.Module):\n",
    "    def __init__(self, input_size, num_titles):\n",
    "        super(ImprovedSongRecommender, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, 128)\n",
    "        self.bn1 = nn.BatchNorm1d(128)\n",
    "        self.fc2 = nn.Linear(128, 256)\n",
    "        self.bn2 = nn.BatchNorm1d(256)\n",
    "        self.fc3 = nn.Linear(256, 128)\n",
    "        self.bn3 = nn.BatchNorm1d(128)\n",
    "        self.output = nn.Linear(128, num_titles)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.bn1(self.fc1(x)))\n",
    "        x = self.dropout(x)\n",
    "        x = torch.relu(self.bn2(self.fc2(x)))\n",
    "        x = self.dropout(x)\n",
    "        x = torch.relu(self.bn3(self.fc3(x)))\n",
    "        x = self.dropout(x)\n",
    "        x = self.output(x)\n",
    "        return x\n",
    "\n",
    "# Load the trained model\n",
    "model_path = \"../models/improved_model.pth\"\n",
    "num_unique_titles = 4855  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)  \n",
    "model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/sklearn/base.py:348: InconsistentVersionWarning: Trying to unpickle estimator LabelEncoder from version 1.2.2 when using version 1.3.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
      "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
      "  warnings.warn(\n",
      "/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/sklearn/base.py:348: InconsistentVersionWarning: Trying to unpickle estimator MinMaxScaler from version 1.2.2 when using version 1.3.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
      "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Load the label encoders and scaler\n",
    "label_encoders_path = \"data/new_label_encoders.joblib\"\n",
    "scaler_path = \"data/new_scaler.joblib\"\n",
    "\n",
    "label_encoders = load(label_encoders_path)\n",
    "scaler = load(scaler_path)\n",
    "\n",
    "# Create a mapping from encoded indices to actual song titles\n",
    "index_to_song_title = {index: title for index, title in enumerate(label_encoders['title'].classes_)}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}