{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2024-12-06T19:54:24.990141Z", "start_time": "2024-12-06T19:53:17.183491Z" } }, "source": [ "!pip install geopy > delete.txt\n", "!pip install datasets > delete.txt\n", "!pip install torch torchvision datasets > delete.txt\n", "!pip install huggingface_hub > delete.txt\n", "!pip install pyhocon > delete.txt\n", "!pip install transformers > delete.txt\n", "!rm delete.txt" ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "'rm' is not recognized as an internal or external command,\n", "operable program or batch file.\n" ] } ], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2024-12-06T19:56:26.136466Z", "start_time": "2024-12-06T19:54:38.679955Z" } }, "cell_type": "code", "source": "!huggingface-cli login", "id": "b0a77c981c32a0c8", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "^C\n" ] } ], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2024-12-06T19:57:30.983629Z", "start_time": "2024-12-06T19:57:29.451887Z" } }, "cell_type": "code", "source": [ "from datasets import load_dataset\n", "\n", "dataset_train = load_dataset(\"CISProject/FOX_NBC\", split=\"train\")\n", "dataset_test = load_dataset(\"path/to/test\", split=\"test\")" ], "id": "a4aa3b759defc904", "outputs": [], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2024-12-06T19:58:41.568459Z", "start_time": "2024-12-06T19:58:41.445848Z" } }, "cell_type": "code", "source": [ "import numpy as np\n", "import torch\n", "from transformers import BertTokenizer\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "\n", "def positional_encoding(seq_len, d_model):\n", " pos_enc = np.zeros((seq_len, d_model))\n", " for pos in range(seq_len):\n", " for i in range(0, d_model, 2):\n", " pos_enc[pos, i] = np.sin(pos / (10000 ** ((2 * i) / d_model)))\n", " if i + 1 < d_model:\n", " pos_enc[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))\n", " return torch.tensor(pos_enc, dtype=torch.float)\n", "\n", "def preprocess_data(data, mode=\"train\", tfidf_vectorizer=None, max_tfidf_features=4096, max_seq_length=128, num_proc=4):\n", " tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", "\n", " # Initialize TF-IDF vectorizer if not provided\n", " if tfidf_vectorizer is None and mode == \"train\":\n", " tfidf_vectorizer = TfidfVectorizer(max_features=max_tfidf_features)\n", "\n", " # Fit TF-IDF only in train mode\n", " if mode == \"train\":\n", " tfidf_vectorizer.fit(data[\"title\"])\n", " print(\"TF-IDF vectorizer fitted on training data.\")\n", "\n", " def process_batch(batch):\n", " headlines = batch[\"title\"]\n", " agencies = batch[\"news\"]\n", "\n", " # TF-IDF transformation (batch-wise)\n", " if mode == \"train\" or tfidf_vectorizer is not None:\n", " freq_inputs = tfidf_vectorizer.transform(headlines).toarray()\n", " else:\n", " raise ValueError(\"TF-IDF vectorizer must be provided in test mode.\")\n", "\n", " # Tokenization (batch-wise)\n", " tokenized = tokenizer(\n", " headlines,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=max_seq_length,\n", " return_tensors=\"pt\"\n", " )\n", "\n", " # Stack input_ids and attention_mask along a new dimension\n", " input_ids = tokenized[\"input_ids\"]\n", " attention_mask = tokenized[\"attention_mask\"]\n", "\n", " # Ensure consistent stacking: (batch_size, 2, seq_len)\n", " seq_inputs = torch.stack([input_ids, 
attention_mask], dim=1)\n", "\n", " # Positional encoding\n", " pos_inputs = positional_encoding(max_seq_length, 512).unsqueeze(0).expand(len(headlines), -1, -1)\n", "\n", " # Labels\n", " labels = [1.0 if agency == \"fox\" else 0.0 for agency in agencies]\n", "\n", " return {\n", " \"freq_inputs\": torch.tensor(freq_inputs),\n", " \"seq_inputs\": seq_inputs,\n", " \"pos_inputs\": pos_inputs,\n", " \"labels\": torch.tensor(labels),\n", " }\n", "\n", " # Use `map` with batching and parallelism\n", " processed_data = data.map(\n", " process_batch,\n", " batched=True,\n", " batch_size=32,\n", " num_proc=num_proc\n", " )\n", "\n", " return processed_data, tfidf_vectorizer" ], "id": "ce6e6b982e22e9fe", "outputs": [ { "ename": "ValueError", "evalue": "numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject", "output_type": "error", "traceback": [ "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", "Cell \u001B[1;32mIn[12], line 4\u001B[0m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BertTokenizer\n\u001B[1;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01msklearn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfeature_extraction\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtext\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m TfidfVectorizer\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mpositional_encoding\u001B[39m(seq_len, d_model):\n\u001B[0;32m 7\u001B[0m pos_enc \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((seq_len, d_model))\n", "File \u001B[1;32m~\\anaconda3\\envs\\CIS5190eval\\lib\\site-packages\\sklearn\\__init__.py:84\u001B[0m\n\u001B[0;32m 70\u001B[0m \u001B[38;5;66;03m# We are not importing the rest of scikit-learn during the build\u001B[39;00m\n\u001B[0;32m 71\u001B[0m \u001B[38;5;66;03m# process, as it may not be compiled yet\u001B[39;00m\n\u001B[0;32m 72\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 78\u001B[0m \u001B[38;5;66;03m# later is linked to the OpenMP runtime to make it possible to introspect\u001B[39;00m\n\u001B[0;32m 79\u001B[0m \u001B[38;5;66;03m# it and importing it first would fail if the OpenMP dll cannot be found.\u001B[39;00m\n\u001B[0;32m 80\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[0;32m 81\u001B[0m __check_build, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 82\u001B[0m _distributor_init, \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[0;32m 83\u001B[0m )\n\u001B[1;32m---> 84\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m clone\n\u001B[0;32m 85\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_show_versions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m show_versions\n\u001B[0;32m 87\u001B[0m __all__ \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m 88\u001B[0m 
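  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "A quick sanity check (added for illustration; not part of the original pipeline): `positional_encoding` should return a `(seq_len, d_model)` tensor of sine/cosine values in `[-1, 1]`. The `d_model=512` below mirrors the constant hard-coded in `preprocess_data`.",
   "id": "pe-sanity-check-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Illustrative check of the positional-encoding helper defined above.\n",
    "pe = positional_encoding(seq_len=128, d_model=512)\n",
    "print(pe.shape)  # expected: torch.Size([128, 512])\n",
    "print(pe.min().item(), pe.max().item())  # sine/cosine values lie in [-1, 1]"
   ],
   "id": "pe-sanity-check",
   "outputs": [],
   "execution_count": null
  },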
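  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Before running the full preprocessing pass, it is worth peeking at the two raw columns the pipeline relies on: `title` (the headline text) and `news` (the outlet label, compared against `\"fox\"` when building labels). This cell is illustrative and was not in the original notebook.",
   "id": "raw-data-peek-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Quick look at the raw training examples.\n",
    "print(dataset_train)               # split size and column names\n",
    "print(dataset_train[0][\"title\"])   # a sample headline\n",
    "print(dataset_train[0][\"news\"])    # its outlet label; labels test for \"fox\""
   ],
   "id": "raw-data-peek",
   "outputs": [],
   "execution_count": null
  },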
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "dataset_train, tfidf_vectorizer = preprocess_data(\n",
    "    data=dataset_train,\n",
    "    mode=\"train\",\n",
    "    max_tfidf_features=8192,\n",
    "    max_seq_length=128\n",
    ")\n",
    "\n",
    "dataset_test, _ = preprocess_data(\n",
    "    data=dataset_test,\n",
    "    mode=\"test\",\n",
    "    tfidf_vectorizer=tfidf_vectorizer,\n",
    "    max_tfidf_features=8192,\n",
    "    max_seq_length=128\n",
    ")"
   ],
   "id": "b605d3b4f5ff547a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load the trained classifier directly from the Hub. If the repo defines a\n",
    "# custom architecture, `trust_remote_code=True` may also be required.\n",
    "from transformers import AutoModel\n",
    "\n",
    "model = AutoModel.from_pretrained(\"CISProject/News-Headline-Classifier-Notebook\")"
   ],
   "id": "b20d11caa1d25445"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# The original cell read the batch size from an undefined `config` object;\n",
    "# a plain constant is used here instead.\n",
    "batch_size = 32\n",
    "\n",
    "# Collate the list-typed mapped features back into batched tensors\n",
    "def collate_fn(batch):\n",
    "    freq_inputs = torch.stack([torch.tensor(item[\"freq_inputs\"]) for item in batch])\n",
    "    seq_inputs = torch.stack([torch.tensor(item[\"seq_inputs\"]) for item in batch])\n",
    "    pos_inputs = torch.stack([torch.tensor(item[\"pos_inputs\"]) for item in batch])\n",
    "    labels = torch.tensor([item[\"labels\"] for item in batch])\n",
    "    return {\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs}, labels\n",
    "\n",
    "train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)\n",
    "test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "\n",
    "def evaluate_model(model, val_loader, criterion, device=\"cpu\"):\n",
    "    model.eval()\n",
    "    val_loss = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_inputs, labels in tqdm(val_loader, desc=\"Testing\", leave=False):\n",
    "            freq_inputs = batch_inputs[\"freq_inputs\"].to(device)\n",
    "            seq_inputs = batch_inputs[\"seq_inputs\"].to(device)\n",
    "            pos_inputs = batch_inputs[\"pos_inputs\"].to(device)\n",
    "            labels = labels[:, None].to(device)\n",
    "\n",
    "            preds = model({\"freq_inputs\": freq_inputs, \"seq_inputs\": seq_inputs, \"pos_inputs\": pos_inputs})\n",
    "            loss = criterion(preds, labels)\n",
    "\n",
    "            val_loss += loss.item()\n",
    "            total += labels.size(0)\n",
    "            correct += ((torch.sigmoid(preds) > 0.5).float() == labels).sum().item()\n",
    "\n",
    "    print(f\"Test Loss: {val_loss / total:.4f}\")\n",
    "    print(f\"Test Accuracy: {correct / total:.4f}\")\n",
    "\n",
    "# Pass the detected device explicitly rather than relying on the default.\n",
    "evaluate_model(model, test_loader, criterion, device=device)\n",
    "# Saving the model in Hugging Face format is sketched in the next cell."
   ],
   "id": "1d23cedfe1d79660",
   "outputs": [],
   "execution_count": null
  },
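  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The evaluation cell above ends with a note about saving the model in Hugging Face format. A minimal sketch of that step follows; the local directory name is illustrative, and the commented `push_to_hub` call assumes write access to the target repo.",
   "id": "save-model-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Minimal sketch: persist the evaluated model in Hugging Face format.\n",
    "save_dir = \"news-headline-classifier\"  # illustrative local directory\n",
    "model.save_pretrained(save_dir)\n",
    "\n",
    "# Optionally publish it back to the Hub (requires write access):\n",
    "# model.push_to_hub(\"CISProject/News-Headline-Classifier-Notebook\")"
   ],
   "id": "save-model",
   "outputs": [],
   "execution_count": null
  }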
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}