Spaces:
Runtime error
Runtime error
File size: 4,522 Bytes
9dd7d9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Evaluating the Recommendation Model**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import gradio as gr\n",
"import torch\n",
"import torch.nn as nn\n",
"from joblib import load\n",
"import sklearn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Re-declare the training-time architecture so the saved state_dict can be loaded into it.\n",
"class ImprovedSongRecommender(nn.Module):\n",
"    \"\"\"MLP that produces one score per song title for each input feature vector.\n",
"\n",
"    Three hidden stages (Linear -> BatchNorm1d -> ReLU -> Dropout) of widths 128, 256, 128,\n",
"    followed by a linear head with one output per title. Submodule names (fc1/bn1, fc2/bn2,\n",
"    fc3/bn3, output) are the checkpoint's state_dict keys and must stay unchanged.\n",
"\n",
"    Args:\n",
"        input_size: number of input features per sample.\n",
"        num_titles: number of title classes (size of the output layer).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, num_titles):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(input_size, 128)\n",
"        self.bn1 = nn.BatchNorm1d(128)\n",
"        self.fc2 = nn.Linear(128, 256)\n",
"        self.bn2 = nn.BatchNorm1d(256)\n",
"        self.fc3 = nn.Linear(256, 128)\n",
"        self.bn3 = nn.BatchNorm1d(128)\n",
"        self.output = nn.Linear(128, num_titles)\n",
"        # A single parameter-free Dropout module is shared by all three hidden stages.\n",
"        self.dropout = nn.Dropout(0.5)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return raw (un-softmaxed) title scores, shape (batch, num_titles), for x of shape (batch, input_size).\"\"\"\n",
"        stages = ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3))\n",
"        hidden = x\n",
"        for linear, norm in stages:\n",
"            # In train mode BatchNorm uses batch statistics (needs batch > 1); call .eval() for inference.\n",
"            hidden = self.dropout(torch.relu(norm(linear(hidden))))\n",
"        return self.output(hidden)\n",
"\n",
"# Checkpoint location and output size needed to rebuild the trained model (it is loaded in the next cell).\n",
"model_path = \"../models/improved_model.pth\"\n",
"# Number of song-title classes the checkpoint was trained with; it must equal the output-layer size\n",
"# stored in the saved weights, otherwise load_state_dict raises a size-mismatch error.\n",
"num_unique_titles = 4855"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rebuild the architecture and load the trained weights onto the CPU.\n",
"# NOTE(review): input_size=2 is presumably the number of scaled input features -- confirm against the scaler.\n",
"model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)\n",
"# The checkpoint is a plain state_dict, so weights_only=True is sufficient and refuses to unpickle\n",
"# arbitrary objects from a tampered .pth file (requires torch >= 1.13).\n",
"state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)\n",
"model.load_state_dict(state_dict)\n",
"# eval() disables dropout and makes BatchNorm use its running statistics for inference.\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/sklearn/base.py:348: InconsistentVersionWarning: Trying to unpickle estimator LabelEncoder from version 1.2.2 when using version 1.3.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
"https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n",
"/Users/mocha/miniconda3/envs/mamba/envs/neurobytes_music_recommender/lib/python3.8/site-packages/sklearn/base.py:348: InconsistentVersionWarning: Trying to unpickle estimator MinMaxScaler from version 1.2.2 when using version 1.3.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
"https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Load the label encoders and scaler saved at training time.\n",
"# NOTE(review): these paths are relative to the working directory ('data/...') while model_path uses\n",
"# '../models/...'; confirm both resolve from the directory the notebook is run in.\n",
"label_encoders_path = \"data/new_label_encoders.joblib\"\n",
"scaler_path = \"data/new_scaler.joblib\"\n",
"\n",
"# joblib.load unpickles arbitrary objects -- only load trusted artifacts.\n",
"label_encoders = load(label_encoders_path)\n",
"scaler = load(scaler_path)\n",
"\n",
"# The model was built with one output per title (num_unique_titles); the title encoder must agree,\n",
"# otherwise output indices cannot be mapped back to titles.\n",
"title_classes = label_encoders['title'].classes_\n",
"assert len(title_classes) == num_unique_titles, (\n",
"    f\"label encoder has {len(title_classes)} titles but the model expects {num_unique_titles}\"\n",
")\n",
"\n",
"# Map encoded class indices (presumably the model's output positions) back to song titles.\n",
"index_to_song_title = dict(enumerate(title_classes))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|