Ryan Kim committed on
Commit
375a093
•
1 Parent(s): 6bec35d

adding code for training

Browse files
data/train.json ADDED
Binary file (58.5 MB). View file
 
data/val.json ADDED
The diff for this file is too large to render. See raw diff
 
logs/1681910017.7615924/events.out.tfevents.1681910017.025fe27979cb.15711.1 ADDED
Binary file (5.81 kB). View file
 
logs/events.out.tfevents.1681910017.025fe27979cb.15711.0 ADDED
Binary file (3.81 kB). View file
 
src/patent_train.ipynb CHANGED
@@ -11,78 +11,51 @@
11
  "cell_type": "markdown",
12
  "metadata": {},
13
  "source": [
14
- "## Importing Packages"
 
 
15
  ]
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 3,
20
  "metadata": {},
21
  "outputs": [
22
  {
23
  "name": "stdout",
24
  "output_type": "stream",
25
  "text": [
26
- "Collecting datasets\n",
27
- " Downloading datasets-2.11.0-py3-none-any.whl (468 kB)\n",
28
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.7/468.7 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
29
- "\u001b[?25hRequirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (4.64.1)\n",
30
- "Requirement already satisfied: pyarrow>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (9.0.0)\n",
31
  "Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (2022.8.2)\n",
32
- "Collecting aiohttp\n",
33
- " Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.0 MB)\n",
34
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
35
- "\u001b[?25hCollecting huggingface-hub<1.0.0,>=0.11.0\n",
36
- " Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)\n",
37
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.8/199.8 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
38
- "\u001b[?25hRequirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets) (1.5.0)\n",
39
- "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets) (21.3)\n",
40
- "Collecting responses<0.19\n",
41
- " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
42
- "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.3.5.1)\n",
43
- "Collecting xxhash\n",
44
- " Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (242 kB)\n",
45
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m242.7/242.7 kB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
46
- "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (2.28.1)\n",
47
- "Collecting multiprocess\n",
48
- " Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)\n",
49
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.3/134.3 kB\u001b[0m \u001b[31m11.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
50
- "\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (6.0)\n",
51
  "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.23.3)\n",
52
- "Collecting multidict<7.0,>=4.5\n",
53
- " Downloading multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (116 kB)\n",
54
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.6/116.6 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
55
- "\u001b[?25hCollecting yarl<2.0,>=1.0\n",
56
- " Downloading yarl-1.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (257 kB)\n",
57
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m257.3/257.3 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
58
- "\u001b[?25hCollecting frozenlist>=1.1.1\n",
59
- " Downloading frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (148 kB)\n",
60
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.1/148.1 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
61
- "\u001b[?25hRequirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (2.1.1)\n",
 
 
62
  "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (22.1.0)\n",
63
- "Collecting aiosignal>=1.1.2\n",
64
- " Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
65
- "Collecting async-timeout<5.0,>=4.0.0a3\n",
66
- " Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
67
- "Collecting filelock\n",
68
- " Downloading filelock-3.10.7-py3-none-any.whl (10 kB)\n",
69
  "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.4.0)\n",
 
70
  "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->datasets) (3.0.9)\n",
 
71
  "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.4)\n",
72
  "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (1.26.11)\n",
73
- "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2022.9.24)\n",
74
- "Collecting dill<0.3.7,>=0.3.0\n",
75
- " Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n",
76
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
77
- "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
78
  "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2022.4)\n",
79
- "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
80
- "Installing collected packages: xxhash, multidict, frozenlist, filelock, dill, async-timeout, yarl, responses, multiprocess, huggingface-hub, aiosignal, aiohttp, datasets\n",
81
- " Attempting uninstall: dill\n",
82
- " Found existing installation: dill 0.3.5.1\n",
83
- " Uninstalling dill-0.3.5.1:\n",
84
- " Successfully uninstalled dill-0.3.5.1\n",
85
- "Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.11.0 dill-0.3.6 filelock-3.10.7 frozenlist-1.3.3 huggingface-hub-0.13.3 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 xxhash-3.2.0 yarl-1.8.2\n"
86
  ]
87
  }
88
  ],
@@ -92,14 +65,13 @@
92
  },
93
  {
94
  "cell_type": "code",
95
- "execution_count": 3,
96
  "metadata": {},
97
  "outputs": [],
98
  "source": [
99
  "from datasets import load_dataset\n",
100
  "import pandas as pd\n",
101
- "import numpy as np\n",
102
- "import matplotlib.pyplot as plt"
103
  ]
104
  },
105
  {
@@ -113,25 +85,25 @@
113
  "cell_type": "markdown",
114
  "metadata": {},
115
  "source": [
116
- "We first need to extract the dataset. We filter only for those in January 2016."
117
  ]
118
  },
119
  {
120
  "cell_type": "code",
121
- "execution_count": 17,
122
  "metadata": {},
123
  "outputs": [
124
  {
125
  "name": "stderr",
126
  "output_type": "stream",
127
  "text": [
128
- "Found cached dataset hupd (/home/jovyan/.cache/huggingface/datasets/HUPD___hupd/sample-ba3b43e1cc5c9c76/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n"
129
  ]
130
  },
131
  {
132
  "data": {
133
  "application/vnd.jupyter.widget-view+json": {
134
- "model_id": "e9c97a02e1834189bcdd4c188ef555d7",
135
  "version_major": 2,
136
  "version_minor": 0
137
  },
@@ -164,7 +136,7 @@
164
  },
165
  {
166
  "cell_type": "code",
167
- "execution_count": 21,
168
  "metadata": {},
169
  "outputs": [
170
  {
@@ -197,7 +169,7 @@
197
  },
198
  {
199
  "cell_type": "code",
200
- "execution_count": 19,
201
  "metadata": {},
202
  "outputs": [],
203
  "source": [
@@ -214,7 +186,7 @@
214
  },
215
  {
216
  "cell_type": "code",
217
- "execution_count": 20,
218
  "metadata": {},
219
  "outputs": [
220
  {
@@ -555,7 +527,7 @@
555
  "[16153 rows x 14 columns]"
556
  ]
557
  },
558
- "execution_count": 20,
559
  "metadata": {},
560
  "output_type": "execute_result"
561
  }
@@ -587,7 +559,7 @@
587
  },
588
  {
589
  "cell_type": "code",
590
- "execution_count": 28,
591
  "metadata": {},
592
  "outputs": [],
593
  "source": [
@@ -597,7 +569,7 @@
597
  },
598
  {
599
  "cell_type": "code",
600
- "execution_count": 29,
601
  "metadata": {},
602
  "outputs": [],
603
  "source": [
@@ -609,7 +581,7 @@
609
  },
610
  {
611
  "cell_type": "code",
612
- "execution_count": 30,
613
  "metadata": {},
614
  "outputs": [
615
  {
@@ -740,7 +712,7 @@
740
  "[8719 rows x 3 columns]"
741
  ]
742
  },
743
- "execution_count": 30,
744
  "metadata": {},
745
  "output_type": "execute_result"
746
  }
@@ -751,7 +723,7 @@
751
  },
752
  {
753
  "cell_type": "code",
754
- "execution_count": 31,
755
  "metadata": {},
756
  "outputs": [],
757
  "source": [
@@ -763,7 +735,7 @@
763
  },
764
  {
765
  "cell_type": "code",
766
- "execution_count": 32,
767
  "metadata": {},
768
  "outputs": [
769
  {
@@ -894,7 +866,7 @@
894
  "[4888 rows x 3 columns]"
895
  ]
896
  },
897
- "execution_count": 32,
898
  "metadata": {},
899
  "output_type": "execute_result"
900
  }
@@ -912,7 +884,7 @@
912
  },
913
  {
914
  "cell_type": "code",
915
- "execution_count": 33,
916
  "metadata": {},
917
  "outputs": [],
918
  "source": [
@@ -921,7 +893,7 @@
921
  },
922
  {
923
  "cell_type": "code",
924
- "execution_count": 34,
925
  "metadata": {},
926
  "outputs": [],
927
  "source": [
@@ -931,7 +903,7 @@
931
  },
932
  {
933
  "cell_type": "code",
934
- "execution_count": 35,
935
  "metadata": {},
936
  "outputs": [
937
  {
@@ -1062,7 +1034,7 @@
1062
  "[8719 rows x 3 columns]"
1063
  ]
1064
  },
1065
- "execution_count": 35,
1066
  "metadata": {},
1067
  "output_type": "execute_result"
1068
  }
@@ -1073,7 +1045,7 @@
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
- "execution_count": 36,
1077
  "metadata": {},
1078
  "outputs": [
1079
  {
@@ -1204,7 +1176,7 @@
1204
  "[4888 rows x 3 columns]"
1205
  ]
1206
  },
1207
- "execution_count": 36,
1208
  "metadata": {},
1209
  "output_type": "execute_result"
1210
  }
@@ -1213,6 +1185,432 @@
1213
  "valDF2"
1214
  ]
1215
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1216
  {
1217
  "cell_type": "code",
1218
  "execution_count": null,
 
11
  "cell_type": "markdown",
12
  "metadata": {},
13
  "source": [
14
+ "## Importing Packages\n",
15
+ "\n",
16
+ "We first need to import the actual USPTO dataset."
17
  ]
18
  },
19
  {
20
  "cell_type": "code",
21
+ "execution_count": 1,
22
  "metadata": {},
23
  "outputs": [
24
  {
25
  "name": "stdout",
26
  "output_type": "stream",
27
  "text": [
28
+ "Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (2.11.0)\n",
 
 
 
 
29
  "Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (2022.8.2)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.23.3)\n",
31
+ "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (4.64.1)\n",
32
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (2.28.1)\n",
33
+ "Requirement already satisfied: pyarrow>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (9.0.0)\n",
34
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (6.0)\n",
35
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets) (1.5.0)\n",
36
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.13.4)\n",
37
+ "Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.18.0)\n",
38
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets) (3.2.0)\n",
39
+ "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.3.6)\n",
40
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets) (3.8.4)\n",
41
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets) (21.3)\n",
42
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets) (0.70.14)\n",
43
  "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (22.1.0)\n",
44
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
45
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.2)\n",
46
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n",
47
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.8.2)\n",
48
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.3)\n",
49
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (2.1.1)\n",
50
  "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.4.0)\n",
51
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (3.12.0)\n",
52
  "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->datasets) (3.0.9)\n",
53
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2022.9.24)\n",
54
  "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.4)\n",
55
  "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (1.26.11)\n",
56
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
 
 
 
 
57
  "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2022.4)\n",
58
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n"
 
 
 
 
 
 
59
  ]
60
  }
61
  ],
 
65
  },
66
  {
67
  "cell_type": "code",
68
+ "execution_count": 2,
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
72
  "from datasets import load_dataset\n",
73
  "import pandas as pd\n",
74
+ "import numpy as np"
 
75
  ]
76
  },
77
  {
 
85
  "cell_type": "markdown",
86
  "metadata": {},
87
  "source": [
88
+ "We need to extract the dataset. We filter only for those in January 2016."
89
  ]
90
  },
91
  {
92
  "cell_type": "code",
93
+ "execution_count": 3,
94
  "metadata": {},
95
  "outputs": [
96
  {
97
  "name": "stderr",
98
  "output_type": "stream",
99
  "text": [
100
+ "Found cached dataset hupd (/home/jovyan/.cache/huggingface/datasets/HUPD___hupd/sample-a4eeba92b4229e93/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n"
101
  ]
102
  },
103
  {
104
  "data": {
105
  "application/vnd.jupyter.widget-view+json": {
106
+ "model_id": "e39fd26828774c8e9d159a8b5d91c4f5",
107
  "version_major": 2,
108
  "version_minor": 0
109
  },
 
136
  },
137
  {
138
  "cell_type": "code",
139
+ "execution_count": 4,
140
  "metadata": {},
141
  "outputs": [
142
  {
 
169
  },
170
  {
171
  "cell_type": "code",
172
+ "execution_count": 5,
173
  "metadata": {},
174
  "outputs": [],
175
  "source": [
 
186
  },
187
  {
188
  "cell_type": "code",
189
+ "execution_count": 6,
190
  "metadata": {},
191
  "outputs": [
192
  {
 
527
  "[16153 rows x 14 columns]"
528
  ]
529
  },
530
+ "execution_count": 6,
531
  "metadata": {},
532
  "output_type": "execute_result"
533
  }
 
559
  },
560
  {
561
  "cell_type": "code",
562
+ "execution_count": 7,
563
  "metadata": {},
564
  "outputs": [],
565
  "source": [
 
569
  },
570
  {
571
  "cell_type": "code",
572
+ "execution_count": 8,
573
  "metadata": {},
574
  "outputs": [],
575
  "source": [
 
581
  },
582
  {
583
  "cell_type": "code",
584
+ "execution_count": 9,
585
  "metadata": {},
586
  "outputs": [
587
  {
 
712
  "[8719 rows x 3 columns]"
713
  ]
714
  },
715
+ "execution_count": 9,
716
  "metadata": {},
717
  "output_type": "execute_result"
718
  }
 
723
  },
724
  {
725
  "cell_type": "code",
726
+ "execution_count": 10,
727
  "metadata": {},
728
  "outputs": [],
729
  "source": [
 
735
  },
736
  {
737
  "cell_type": "code",
738
+ "execution_count": 11,
739
  "metadata": {},
740
  "outputs": [
741
  {
 
866
  "[4888 rows x 3 columns]"
867
  ]
868
  },
869
+ "execution_count": 11,
870
  "metadata": {},
871
  "output_type": "execute_result"
872
  }
 
884
  },
885
  {
886
  "cell_type": "code",
887
+ "execution_count": 12,
888
  "metadata": {},
889
  "outputs": [],
890
  "source": [
 
893
  },
894
  {
895
  "cell_type": "code",
896
+ "execution_count": 13,
897
  "metadata": {},
898
  "outputs": [],
899
  "source": [
 
903
  },
904
  {
905
  "cell_type": "code",
906
+ "execution_count": 14,
907
  "metadata": {},
908
  "outputs": [
909
  {
 
1034
  "[8719 rows x 3 columns]"
1035
  ]
1036
  },
1037
+ "execution_count": 14,
1038
  "metadata": {},
1039
  "output_type": "execute_result"
1040
  }
 
1045
  },
1046
  {
1047
  "cell_type": "code",
1048
+ "execution_count": 15,
1049
  "metadata": {},
1050
  "outputs": [
1051
  {
 
1176
  "[4888 rows x 3 columns]"
1177
  ]
1178
  },
1179
+ "execution_count": 15,
1180
  "metadata": {},
1181
  "output_type": "execute_result"
1182
  }
 
1185
  "valDF2"
1186
  ]
1187
  },
1188
+ {
1189
+ "cell_type": "markdown",
1190
+ "metadata": {},
1191
+ "source": [
1192
+ "We combine the `abstract` and `claims` columns into a single `text` column. We also re-label the `decision` column to `label`."
1193
+ ]
1194
+ },
1195
+ {
1196
+ "cell_type": "code",
1197
+ "execution_count": 16,
1198
+ "metadata": {},
1199
+ "outputs": [
1200
+ {
1201
+ "data": {
1202
+ "text/html": [
1203
+ "<div>\n",
1204
+ "<style scoped>\n",
1205
+ " .dataframe tbody tr th:only-of-type {\n",
1206
+ " vertical-align: middle;\n",
1207
+ " }\n",
1208
+ "\n",
1209
+ " .dataframe tbody tr th {\n",
1210
+ " vertical-align: top;\n",
1211
+ " }\n",
1212
+ "\n",
1213
+ " .dataframe thead th {\n",
1214
+ " text-align: right;\n",
1215
+ " }\n",
1216
+ "</style>\n",
1217
+ "<table border=\"1\" class=\"dataframe\">\n",
1218
+ " <thead>\n",
1219
+ " <tr style=\"text-align: right;\">\n",
1220
+ " <th></th>\n",
1221
+ " <th>label</th>\n",
1222
+ " <th>text</th>\n",
1223
+ " </tr>\n",
1224
+ " </thead>\n",
1225
+ " <tbody>\n",
1226
+ " <tr>\n",
1227
+ " <th>0</th>\n",
1228
+ " <td>1</td>\n",
1229
+ " <td>The present invention relates to passive optic...</td>\n",
1230
+ " </tr>\n",
1231
+ " <tr>\n",
1232
+ " <th>1</th>\n",
1233
+ " <td>1</td>\n",
1234
+ " <td>Embodiments of the invention provide a method ...</td>\n",
1235
+ " </tr>\n",
1236
+ " <tr>\n",
1237
+ " <th>3</th>\n",
1238
+ " <td>1</td>\n",
1239
+ " <td>A crystal growth furnace comprising a crucible...</td>\n",
1240
+ " </tr>\n",
1241
+ " <tr>\n",
1242
+ " <th>4</th>\n",
1243
+ " <td>0</td>\n",
1244
+ " <td>A shoe midsole is composed of a base plate (1)...</td>\n",
1245
+ " </tr>\n",
1246
+ " <tr>\n",
1247
+ " <th>5</th>\n",
1248
+ " <td>1</td>\n",
1249
+ " <td>A ratchet tool includes a shaft member, a hand...</td>\n",
1250
+ " </tr>\n",
1251
+ " <tr>\n",
1252
+ " <th>...</th>\n",
1253
+ " <td>...</td>\n",
1254
+ " <td>...</td>\n",
1255
+ " </tr>\n",
1256
+ " <tr>\n",
1257
+ " <th>16144</th>\n",
1258
+ " <td>1</td>\n",
1259
+ " <td>A wavelength tunable laser device, including: ...</td>\n",
1260
+ " </tr>\n",
1261
+ " <tr>\n",
1262
+ " <th>16145</th>\n",
1263
+ " <td>1</td>\n",
1264
+ " <td>In one aspect, a method for use in preparing a...</td>\n",
1265
+ " </tr>\n",
1266
+ " <tr>\n",
1267
+ " <th>16148</th>\n",
1268
+ " <td>1</td>\n",
1269
+ " <td>A robot hand controlling method executes calcu...</td>\n",
1270
+ " </tr>\n",
1271
+ " <tr>\n",
1272
+ " <th>16149</th>\n",
1273
+ " <td>0</td>\n",
1274
+ " <td>A fusion protein is disclosed. The fusion prot...</td>\n",
1275
+ " </tr>\n",
1276
+ " <tr>\n",
1277
+ " <th>16150</th>\n",
1278
+ " <td>0</td>\n",
1279
+ " <td>A pipe extraction tool that grips the inside o...</td>\n",
1280
+ " </tr>\n",
1281
+ " </tbody>\n",
1282
+ "</table>\n",
1283
+ "<p>8719 rows × 2 columns</p>\n",
1284
+ "</div>"
1285
+ ],
1286
+ "text/plain": [
1287
+ " label text\n",
1288
+ "0 1 The present invention relates to passive optic...\n",
1289
+ "1 1 Embodiments of the invention provide a method ...\n",
1290
+ "3 1 A crystal growth furnace comprising a crucible...\n",
1291
+ "4 0 A shoe midsole is composed of a base plate (1)...\n",
1292
+ "5 1 A ratchet tool includes a shaft member, a hand...\n",
1293
+ "... ... ...\n",
1294
+ "16144 1 A wavelength tunable laser device, including: ...\n",
1295
+ "16145 1 In one aspect, a method for use in preparing a...\n",
1296
+ "16148 1 A robot hand controlling method executes calcu...\n",
1297
+ "16149 0 A fusion protein is disclosed. The fusion prot...\n",
1298
+ "16150 0 A pipe extraction tool that grips the inside o...\n",
1299
+ "\n",
1300
+ "[8719 rows x 2 columns]"
1301
+ ]
1302
+ },
1303
+ "execution_count": 16,
1304
+ "metadata": {},
1305
+ "output_type": "execute_result"
1306
+ }
1307
+ ],
1308
+ "source": [
1309
+ "trainDF3 = trainDF2.rename(columns={'decision': 'label'})\n",
1310
+ "trainDF3['text'] = trainDF3['abstract'] + ' ' + trainDF3['claims']\n",
1311
+ "trainDF3.drop(columns=[\"abstract\",\"claims\"],inplace=True)\n",
1312
+ "trainDF3"
1313
+ ]
1314
+ },
1315
+ {
1316
+ "cell_type": "code",
1317
+ "execution_count": 17,
1318
+ "metadata": {},
1319
+ "outputs": [
1320
+ {
1321
+ "data": {
1322
+ "text/html": [
1323
+ "<div>\n",
1324
+ "<style scoped>\n",
1325
+ " .dataframe tbody tr th:only-of-type {\n",
1326
+ " vertical-align: middle;\n",
1327
+ " }\n",
1328
+ "\n",
1329
+ " .dataframe tbody tr th {\n",
1330
+ " vertical-align: top;\n",
1331
+ " }\n",
1332
+ "\n",
1333
+ " .dataframe thead th {\n",
1334
+ " text-align: right;\n",
1335
+ " }\n",
1336
+ "</style>\n",
1337
+ "<table border=\"1\" class=\"dataframe\">\n",
1338
+ " <thead>\n",
1339
+ " <tr style=\"text-align: right;\">\n",
1340
+ " <th></th>\n",
1341
+ " <th>label</th>\n",
1342
+ " <th>text</th>\n",
1343
+ " </tr>\n",
1344
+ " </thead>\n",
1345
+ " <tbody>\n",
1346
+ " <tr>\n",
1347
+ " <th>0</th>\n",
1348
+ " <td>0</td>\n",
1349
+ " <td>Regimen for the treatment of rosacea include t...</td>\n",
1350
+ " </tr>\n",
1351
+ " <tr>\n",
1352
+ " <th>1</th>\n",
1353
+ " <td>1</td>\n",
1354
+ " <td>A clamp arrangement includes a pair of bracket...</td>\n",
1355
+ " </tr>\n",
1356
+ " <tr>\n",
1357
+ " <th>2</th>\n",
1358
+ " <td>0</td>\n",
1359
+ " <td>A system and method for device action and conf...</td>\n",
1360
+ " </tr>\n",
1361
+ " <tr>\n",
1362
+ " <th>4</th>\n",
1363
+ " <td>0</td>\n",
1364
+ " <td>Systems and methods for managing datasets prod...</td>\n",
1365
+ " </tr>\n",
1366
+ " <tr>\n",
1367
+ " <th>9</th>\n",
1368
+ " <td>1</td>\n",
1369
+ " <td>A scan driving circuit is provided. The scan d...</td>\n",
1370
+ " </tr>\n",
1371
+ " <tr>\n",
1372
+ " <th>...</th>\n",
1373
+ " <td>...</td>\n",
1374
+ " <td>...</td>\n",
1375
+ " </tr>\n",
1376
+ " <tr>\n",
1377
+ " <th>9085</th>\n",
1378
+ " <td>0</td>\n",
1379
+ " <td>The non-rigid gate device as described may be ...</td>\n",
1380
+ " </tr>\n",
1381
+ " <tr>\n",
1382
+ " <th>9090</th>\n",
1383
+ " <td>0</td>\n",
1384
+ " <td>The present invention provides an improved unc...</td>\n",
1385
+ " </tr>\n",
1386
+ " <tr>\n",
1387
+ " <th>9091</th>\n",
1388
+ " <td>1</td>\n",
1389
+ " <td>A method for detecting a software-race conditi...</td>\n",
1390
+ " </tr>\n",
1391
+ " <tr>\n",
1392
+ " <th>9092</th>\n",
1393
+ " <td>1</td>\n",
1394
+ " <td>The present application relates to multi-stage...</td>\n",
1395
+ " </tr>\n",
1396
+ " <tr>\n",
1397
+ " <th>9093</th>\n",
1398
+ " <td>1</td>\n",
1399
+ " <td>A paper feeder includes a housing, a driving u...</td>\n",
1400
+ " </tr>\n",
1401
+ " </tbody>\n",
1402
+ "</table>\n",
1403
+ "<p>4888 rows × 2 columns</p>\n",
1404
+ "</div>"
1405
+ ],
1406
+ "text/plain": [
1407
+ " label text\n",
1408
+ "0 0 Regimen for the treatment of rosacea include t...\n",
1409
+ "1 1 A clamp arrangement includes a pair of bracket...\n",
1410
+ "2 0 A system and method for device action and conf...\n",
1411
+ "4 0 Systems and methods for managing datasets prod...\n",
1412
+ "9 1 A scan driving circuit is provided. The scan d...\n",
1413
+ "... ... ...\n",
1414
+ "9085 0 The non-rigid gate device as described may be ...\n",
1415
+ "9090 0 The present invention provides an improved unc...\n",
1416
+ "9091 1 A method for detecting a software-race conditi...\n",
1417
+ "9092 1 The present application relates to multi-stage...\n",
1418
+ "9093 1 A paper feeder includes a housing, a driving u...\n",
1419
+ "\n",
1420
+ "[4888 rows x 2 columns]"
1421
+ ]
1422
+ },
1423
+ "execution_count": 17,
1424
+ "metadata": {},
1425
+ "output_type": "execute_result"
1426
+ }
1427
+ ],
1428
+ "source": [
1429
+ "valDF3 = valDF2.rename(columns={'decision': 'label'})\n",
1430
+ "valDF3['text'] = valDF3['abstract'] + ' ' + valDF3['claims']\n",
1431
+ "valDF3.drop(columns=[\"abstract\",\"claims\"],inplace=True)\n",
1432
+ "valDF3"
1433
+ ]
1434
+ },
1435
+ {
1436
+ "cell_type": "markdown",
1437
+ "metadata": {},
1438
+ "source": [
1439
+ "We can grab the data for each column so that we have a list of values for training labels, training texts, validation labels, and validation texts."
1440
+ ]
1441
+ },
1442
+ {
1443
+ "cell_type": "code",
1444
+ "execution_count": 18,
1445
+ "metadata": {},
1446
+ "outputs": [],
1447
+ "source": [
1448
+ "trainLabels = trainDF3[\"label\"].tolist()\n",
1449
+ "trainText = trainDF3[\"text\"].tolist()\n",
1450
+ "\n",
1451
+ "valLabels = valDF3[\"label\"].tolist()\n",
1452
+ "valText = valDF3[\"text\"].tolist()"
1453
+ ]
1454
+ },
1455
+ {
1456
+ "cell_type": "markdown",
1457
+ "metadata": {},
1458
+ "source": [
1459
+ "## Loading the Trainer\n",
1460
+ "\n",
1461
+ "Now we can start training! This time, we will just go with `distilbert-base-uncased` for simplicity."
1462
+ ]
1463
+ },
1464
+ {
1465
+ "cell_type": "code",
1466
+ "execution_count": 19,
1467
+ "metadata": {},
1468
+ "outputs": [
1469
+ {
1470
+ "name": "stdout",
1471
+ "output_type": "stream",
1472
+ "text": [
1473
+ "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (2.0.0)\n",
1474
+ "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /opt/conda/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
1475
+ "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /opt/conda/lib/python3.10/site-packages (from torch) (11.7.91)\n",
1476
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch) (3.1.2)\n",
1477
+ "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch) (4.4.0)\n",
1478
+ "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /opt/conda/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
1479
+ "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /opt/conda/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
1480
+ "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
1481
+ "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /opt/conda/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
1482
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch) (1.11.1)\n",
1483
+ "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.10/site-packages (from torch) (11.7.99)\n",
1484
+ "Requirement already satisfied: triton==2.0.0 in /opt/conda/lib/python3.10/site-packages (from torch) (2.0.0)\n",
1485
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch) (3.12.0)\n",
1486
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch) (2.8.7)\n",
1487
+ "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /opt/conda/lib/python3.10/site-packages (from torch) (2.14.3)\n",
1488
+ "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /opt/conda/lib/python3.10/site-packages (from torch) (11.7.101)\n",
1489
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.10/site-packages (from torch) (11.7.99)\n",
1490
+ "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
1491
+ "Requirement already satisfied: setuptools in /opt/conda/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (65.4.1)\n",
1492
+ "Requirement already satisfied: wheel in /opt/conda/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.37.1)\n",
1493
+ "Requirement already satisfied: cmake in /opt/conda/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.26.3)\n",
1494
+ "Requirement already satisfied: lit in /opt/conda/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.1)\n",
1495
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch) (2.1.1)\n",
1496
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch) (1.2.1)\n",
1497
+ "Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (4.28.1)\n",
1498
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.10/site-packages (from transformers) (0.13.4)\n",
1499
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers) (6.0)\n",
1500
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers) (1.23.3)\n",
1501
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
1502
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers) (2023.3.23)\n",
1503
+ "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers) (4.64.1)\n",
1504
+ "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers) (2.28.1)\n",
1505
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
1506
+ "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers) (21.3)\n",
1507
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n",
1508
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->transformers) (3.0.9)\n",
1509
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (1.26.11)\n",
1510
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (2022.9.24)\n",
1511
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (3.4)\n",
1512
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (2.1.1)\n"
1513
+ ]
1514
+ }
1515
+ ],
1516
+ "source": [
1517
+ "!pip install torch\n",
1518
+ "!pip install transformers"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "cell_type": "code",
1523
+ "execution_count": 20,
1524
+ "metadata": {},
1525
+ "outputs": [],
1526
+ "source": [
1527
+ "import torch\n",
1528
+ "from torch.utils.data import Dataset\n",
1529
+ "from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
1530
+ "from transformers import Trainer, TrainingArguments"
1531
+ ]
1532
+ },
1533
+ {
1534
+ "cell_type": "code",
1535
+ "execution_count": 21,
1536
+ "metadata": {},
1537
+ "outputs": [],
1538
+ "source": [
1539
+ "model_name = \"distilbert-base-uncased\"\n",
1540
+ "class USPTODataset(Dataset):\n",
1541
+ " def __init__(self, encodings, labels):\n",
1542
+ " self.encodings = encodings\n",
1543
+ " self.labels = labels\n",
1544
+ " def __getitem__(self, idx):\n",
1545
+ " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
1546
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
1547
+ " return item\n",
1548
+ " def __len__(self):\n",
1549
+ " return len(self.labels)\n"
1550
+ ]
1551
+ },
1552
+ {
1553
+ "cell_type": "code",
1554
+ "execution_count": 22,
1555
+ "metadata": {},
1556
+ "outputs": [],
1557
+ "source": [
1558
+ "tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)"
1559
+ ]
1560
+ },
1561
+ {
1562
+ "cell_type": "code",
1563
+ "execution_count": null,
1564
+ "metadata": {},
1565
+ "outputs": [],
1566
+ "source": [
1567
+ "train_encodings = tokenizer(trainText, truncation=True, padding=True)\n",
1568
+ "val_encodings = tokenizer(valText, truncation=True, padding=True)\n",
1569
+ "\n",
1570
+ "train_dataset = USPTODataset(train_encodings, trainLabels)\n",
1571
+ "val_dataset = USPTODataset(val_encodings, valLabels)\n",
1572
+ "\n",
1573
+ "train_args = TrainingArguments(\n",
1574
+ " output_dir=\"./results\",\n",
1575
+ " num_train_epochs=2,\n",
1576
+ " per_device_train_batch_size=16,\n",
1577
+ " per_device_eval_batch_size=64,\n",
1578
+ " warmup_steps=500,\n",
1579
+ " learning_rate=5e-5,\n",
1580
+ " weight_decay=0.01,\n",
1581
+ " logging_dir=\"./logs\",\n",
1582
+ " logging_steps=10\n",
1583
+ ")"
1584
+ ]
1585
+ },
1586
+ {
1587
+ "cell_type": "code",
1588
+ "execution_count": null,
1589
+ "metadata": {},
1590
+ "outputs": [],
1591
+ "source": []
1592
+ },
1593
+ {
1594
+ "cell_type": "code",
1595
+ "execution_count": null,
1596
+ "metadata": {},
1597
+ "outputs": [],
1598
+ "source": []
1599
+ },
1600
+ {
1601
+ "cell_type": "code",
1602
+ "execution_count": null,
1603
+ "metadata": {},
1604
+ "outputs": [],
1605
+ "source": []
1606
+ },
1607
+ {
1608
+ "cell_type": "code",
1609
+ "execution_count": null,
1610
+ "metadata": {},
1611
+ "outputs": [],
1612
+ "source": []
1613
+ },
1614
  {
1615
  "cell_type": "code",
1616
  "execution_count": null,
src/train.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import pandas as pd
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
9
+ from transformers import Trainer, TrainingArguments, AdamW
10
+
11
+ model_name = "distilbert-base-uncased"
12
+
13
class USPTODataset(Dataset):
    """Torch dataset that pairs tokenizer encodings with per-example labels.

    `encodings` is a mapping whose values can be indexed by example position
    (it is iterated with `.items()` and each value is indexed with `idx`);
    `labels` is a sequence indexed the same way.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the labels, not by the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        # One tensor per encoding field, plus the label under the key 'labels'.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
23
+
24
def _select_decided_rows(frame, columns, decisions):
    """Keep only `columns` of `frame`, dropping incomplete and undecided rows.

    The columns are selected *before* `dropna()` so that a missing value in a
    column the script never uses cannot discard an otherwise usable row.
    """
    subset = frame[columns].dropna()
    return subset[subset["decision"].isin(decisions)]


def _encode_decisions(frame, label_by_decision):
    """Return a copy of `frame` with `decision` strings mapped to integers."""
    return frame.assign(decision=frame["decision"].map(label_by_decision))


def _merge_text_columns(frame):
    """Return a frame with a `label` column (from `decision`) and a `text` column.

    `text` is the abstract and the claims joined by a single space.
    """
    return pd.DataFrame({
        "label": frame["decision"],
        "text": frame["abstract"] + " " + frame["claims"],
    })


def LoadDataset():
    """Download the HUPD sample, preprocess it, and cache it as JSON.

    The training split is requested with filing dates 2016-01-01 to
    2016-01-21 and the validation split with 2016-01-22 to 2016-01-31.
    Only applications whose decision is "ACCEPTED" or "REJECTED" are kept;
    they are labelled 1 and 0 respectively.

    Side effects: creates `./data` if needed and writes `./data/train.json`
    and `./data/val.json`.

    Returns:
        A `(trainData, valData)` tuple. Each element is a dict with two
        parallel lists: `"labels"` (ints) and `"text"` (abstract + claims).
    """
    print("=== LOADING THE DATASET ===")
    # Extracting the dataset, filtering only for Jan. 2016
    dataset_dict = load_dataset(
        'HUPD/hupd',
        name='sample',
        data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
        icpr_label=None,
        train_filing_start_date='2016-01-01',
        train_filing_end_date='2016-01-21',
        val_filing_start_date='2016-01-22',
        val_filing_end_date='2016-01-31',
    )

    print("Separating between training and validation data")
    df_train = pd.DataFrame(dataset_dict['train'])
    df_val = pd.DataFrame(dataset_dict['validation'])

    print("=== PRE-PROCESSING THE DATASET ===")
    # We are interested in the following columns:
    #   - Abstract
    #   - Claims
    #   - Decision  <- our `y`
    # The "Decision" column has three types of values: "Accepted", "Rejected",
    # and "Pending". To remove unnecessary baggage, we only keep "Accepted"
    # and "Rejected".
    necessary_columns = ["abstract", "claims", "decision"]
    output_values = ['ACCEPTED', 'REJECTED']

    print("Dropping unused columns")
    trainDF = _select_decided_rows(df_train, necessary_columns, output_values)
    valDF = _select_decided_rows(df_val, necessary_columns, output_values)

    # We need to replace the values in the `decision` column with numerical
    # representations: "ACCEPTED" becomes `1` and "REJECTED" becomes `0`.
    print("Replacing values in `decision` column")
    yKey = {"ACCEPTED": 1, "REJECTED": 0}
    trainDF = _encode_decisions(trainDF, yKey)
    valDF = _encode_decisions(valDF, yKey)

    # We combine the `abstract` and `claims` columns into a single `text`
    # column, and re-label the `decision` column to `label`.
    print("Combining columns and renaming `decision` to `label`")
    trainDF = _merge_text_columns(trainDF)
    valDF = _merge_text_columns(valDF)

    # Grab the data of each column as plain lists so they can be serialised
    # to JSON and handed to the tokenizer.
    print("Extracting label and text data from dataframes")
    trainData = {
        "labels": trainDF["label"].tolist(),
        "text": trainDF["text"].tolist(),
    }
    valData = {
        "labels": valDF["label"].tolist(),
        "text": valDF["text"].tolist(),
    }
    print(f'TRAINING:\t# labels: {len(trainData["labels"])}\t# texts: {len(trainData["text"])}')
    print(f'VALID:\t# labels: {len(valData["labels"])}\t# texts: {len(valData["text"])}')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('./data', exist_ok=True)

    with open("./data/train.json", "w") as outfile:
        json.dump(trainData, outfile, indent=2)
    with open("./data/val.json", "w") as outfile:
        json.dump(valData, outfile, indent=2)

    return trainData, valData
104
+
105
def main():
    """Fine-tune DistilBERT on the cached (or freshly built) HUPD data.

    Reads `./data/train.json` and `./data/val.json` when both exist;
    otherwise builds them through `LoadDataset()`. The texts are tokenized,
    wrapped in `USPTODataset`, and passed to a Hugging Face `Trainer`, whose
    `train()` method is then called. Trainer output goes to `./results` and
    logs to `./logs`.
    """
    trainDataPath = "./data/train.json"
    valDataPath = "./data/val.json"

    if os.path.exists(trainDataPath) and os.path.exists(valDataPath):
        # Context managers close the cache files even if json.load raises.
        with open(trainDataPath) as ftrain:
            trainData = json.load(ftrain)
        with open(valDataPath) as fval:
            valData = json.load(fval)
    else:
        trainData, valData = LoadDataset()

    print(len(trainData["labels"]), len(trainData["text"]), len(valData["labels"]), len(valData["text"]))

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    train_encodings = tokenizer(trainData["text"], truncation=True, padding=True)
    val_encodings = tokenizer(valData["text"], truncation=True, padding=True)

    train_dataset = USPTODataset(train_encodings, trainData["labels"])
    val_dataset = USPTODataset(val_encodings, valData["labels"])

    train_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        learning_rate=5e-5,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
    )

    model = DistilBertForSequenceClassification.from_pretrained(model_name)
    # NOTE(review): no evaluation strategy is set in TrainingArguments, so
    # eval_dataset is presumably only used if evaluate() is called - confirm.
    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()