hiverlab-nicholastkb committed
Commit a54ef97
1 Parent(s): f77931d

added demo.ipynb

Files changed (3)
  1. README.md +2 -2
  2. allDataShort2.json +0 -0
  3. demo.ipynb +929 -0
README.md CHANGED
@@ -41,7 +41,7 @@ The model is based on the [SciBERT](https://github.com/allenai/scibert) architec
  You can access an interactive web interface for querying the fine-tuned LGL model [here](spacelink). If you prefer to load the model yourself, you can check out [Installation](#installation) below.
 
  ## Installation
- To use LGL, you need to install the required dependencies and download the model files. Follow the steps below to set up the environment:
+ If you prefer to run LGL locally or conduct further fine-tuning, you need to install the required dependencies and download the model files. Follow the steps below to set up the environment:
 
  1. Clone this repository to your local machine.
  1.1 If you do not have Python installed, download Python via the official sources. Anaconda is recommended if you use scientific packages often.
@@ -51,7 +51,7 @@ If using anaconda, after installation setup a new conda environment via the foll
 
  2. Activate your venv/conda env (if using) and install the required Python packages using `pip`:
 
- ```pip install -r requirements.txt```
+ ```pip install -r requirements_local.txt```
 
  3. To utilize the fine-tuned NER model for recognizing drugs, genes, and diseases, open `demo.ipynb` in Jupyter Lab (start it via ```jupyter lab```). The script takes text input as a string and returns the identified entities along with their respective labels.
allDataShort2.json ADDED
The diff for this file is too large to render. See raw diff
 
demo.ipynb ADDED
@@ -0,0 +1,929 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0205ad7a",
+ "metadata": {},
+ "source": [
+ "# SciBERT Fine-Tuning\n",
+ "\n",
+ "---\n",
+ "\n",
+ "In this notebook, we use the 🤗 `transformers` library to fine-tune the `allenai/scibert_scivocab_uncased` model on various datasets. The goal is for the fine-tuned model to perform Named Entity Recognition of Drugs, Diseases, and Genes."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f94afa73-cca7-4cd5-8d74-8971f84eddc8",
+ "metadata": {},
+ "source": [
+ "!pip install --target=$\"modules\" datasets transformers\n",
+ "!pip install --target=$\"modules\" seqeval\n",
+ "!pip install --target=$\"modules\" spacy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "849c4d25-4f4e-4e18-b17e-665d4abf90d6",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "!pip install --target=$\"modules\" evaluate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11034662-5c9f-45be-99d2-634d311a7d60",
+ "metadata": {},
+ "source": [
+ "!pip install --target=$\"modules\" bioc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "2f674f50",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\nicho\\miniconda3\\envs\\LGL1\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ "  from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from datasets import Dataset, ClassLabel, Sequence, load_dataset, load_metric\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import bioc\n",
+ "from spacy import displacy\n",
+ "import transformers\n",
+ "#import evaluate\n",
+ "from transformers import (AutoModelForTokenClassification,\n",
+ "                          AutoTokenizer,\n",
+ "                          DataCollatorForTokenClassification,\n",
+ "                          pipeline,\n",
+ "                          TrainingArguments,\n",
+ "                          Trainer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "9c3c5157",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4.31.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# confirm version > 4.11.0\n",
+ "print(transformers.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "05ebaa92-58d0-40d4-ade9-a2ed6053827d",
+ "metadata": {},
+ "source": [
+ "## Reading files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "101c50e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# no train-test split provided, so we create our own\n",
+ "cons_dataset = load_dataset(\"json\", data_files=\"./allDataShort2.json\")\n",
+ "#cons_dataset = datasets\n",
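+ "# hold out a tiny test split: test_size=0.001 of the 6,561 rows leaves just 7 examples as a quick sanity check\n",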
+ "cons_dataset = cons_dataset[\"train\"].train_test_split(test_size=0.001)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "db6c1a35-ba5f-4976-aa27-14293fe676e8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ "    train: Dataset({\n",
+ "        features: ['index', 'text', 'gene', 'gene_indices_start', 'gene_indices_end', 'disease', 'disease_indices_start', 'disease_indices_end', 'drugs', 'drug_indices_start', 'drug_indices_end'],\n",
+ "        num_rows: 6554\n",
+ "    })\n",
+ "    test: Dataset({\n",
+ "        features: ['index', 'text', 'gene', 'gene_indices_start', 'gene_indices_end', 'disease', 'disease_indices_start', 'disease_indices_end', 'drugs', 'drug_indices_start', 'drug_indices_end'],\n",
+ "        num_rows: 7\n",
+ "    })\n",
+ "})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cons_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a9bfbe10",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## Token Labeling\n",
+ "\n",
+ "Finally, we can label each token with its entity. We use BIO tagging on three entities, `DRUG`, `DISEASE`, and `GENE`. This results in seven possible classes for each token:\n",
+ "\n",
+ "* `O` - outside any entity we care about\n",
+ "* `B-DRUG` - the beginning of a `DRUG` entity\n",
+ "* `I-DRUG` - inside a `DRUG` entity\n",
+ "* `B-DISEASE` - the beginning of a `DISEASE` entity\n",
+ "* `I-DISEASE` - inside a `DISEASE` entity\n",
+ "* `B-GENE` - the beginning of a `GENE` entity\n",
+ "* `I-GENE` - inside a `GENE` entity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "89d0ec15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
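+ "# a tag's position in this list is the integer class id the model is trained against\n",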
+ "label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']\n",
+ "\n",
+ "custom_seq = Sequence(feature=ClassLabel(num_classes=7,\n",
+ "                                         names=label_list,\n",
+ "                                         names_file=None, id=None), length=-1, id=None)\n",
+ "\n",
+ "cons_dataset[\"train\"].features[\"ner_tags\"] = custom_seq\n",
+ "cons_dataset[\"test\"].features[\"ner_tags\"] = custom_seq"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "48c50247",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#model_checkpoint = \"allenai/scibert_scivocab_uncased\"\n",
+ "model_checkpoint = './trainedSB2'\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c9481628",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from math import isnan\n",
+ "def generate_row_labels(row, verbose=False):\n",
+ "    \"\"\"Given a row from the consolidated dataset,\n",
+ "    generates BIO tags for drug, disease, and gene entities.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    text = row[\"text\"]\n",
+ "\n",
+ "    labels = []\n",
+ "    label = \"O\"\n",
+ "    prefix = \"\"\n",
+ "\n",
+ "    # while iterating through tokens, increment to traverse all drug, disease, and gene spans\n",
+ "    drug_index = 0\n",
+ "    disease_index = 0\n",
+ "    gene_index = 0\n",
+ "\n",
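+ "    # return_offsets_mapping gives each subword token its (start, end) character span in text\n",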
+ "    tokens = tokenizer(text, return_offsets_mapping=True)\n",
+ "\n",
+ "    for n in range(len(tokens[\"input_ids\"])):\n",
+ "        offset_start, offset_end = tokens[\"offset_mapping\"][n]\n",
+ "\n",
+ "        # should only happen for [CLS] and [SEP]\n",
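+ "        # tag them -100 so the loss function and the seqeval metrics below skip these positions\n",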
+ "        if offset_end - offset_start == 0:\n",
+ "            labels.append(-100)\n",
+ "            continue\n",
+ "\n",
+ "        if (type(row[\"drug_indices_start\"]) == list) and drug_index < len(row[\"drug_indices_start\"]) and offset_start == row[\"drug_indices_start\"][drug_index]:\n",
+ "            label = \"DRUG\"\n",
+ "            prefix = \"B-\"\n",
+ "\n",
+ "        elif (type(row[\"disease_indices_start\"]) == list) and disease_index < len(row[\"disease_indices_start\"]) and offset_start == row[\"disease_indices_start\"][disease_index]:\n",
+ "            label = \"DISEASE\"\n",
+ "            prefix = \"B-\"\n",
+ "\n",
+ "        elif (type(row[\"gene_indices_start\"]) == list) and gene_index < len(row[\"gene_indices_start\"]) and offset_start == row[\"gene_indices_start\"][gene_index]:\n",
+ "            label = \"GENE\"\n",
+ "            prefix = \"B-\"\n",
+ "\n",
+ "        labels.append(label_list.index(f\"{prefix}{label}\"))\n",
+ "\n",
+ "        if (type(row[\"drug_indices_end\"]) == list) and drug_index < len(row[\"drug_indices_end\"]) and offset_end == row[\"drug_indices_end\"][drug_index]:\n",
+ "            label = \"O\"\n",
+ "            prefix = \"\"\n",
+ "            drug_index += 1\n",
+ "\n",
+ "        elif (type(row[\"disease_indices_end\"]) == list) and disease_index < len(row[\"disease_indices_end\"]) and offset_end == row[\"disease_indices_end\"][disease_index]:\n",
+ "            label = \"O\"\n",
+ "            prefix = \"\"\n",
+ "            disease_index += 1\n",
+ "\n",
+ "        elif (type(row[\"gene_indices_end\"]) == list) and gene_index < len(row[\"gene_indices_end\"]) and offset_end == row[\"gene_indices_end\"][gene_index]:\n",
+ "            label = \"O\"\n",
+ "            prefix = \"\"\n",
+ "            gene_index += 1\n",
+ "\n",
+ "        # need to transition \"inside\" if we just entered an entity\n",
+ "        if prefix == \"B-\":\n",
+ "            prefix = \"I-\"\n",
+ "\n",
+ "    if verbose:\n",
+ "        print(f\"{row}\\n\")\n",
+ "        orig = tokenizer.convert_ids_to_tokens(tokens[\"input_ids\"])\n",
+ "        for n in range(len(labels)):\n",
+ "            print(orig[n], labels[n])\n",
+ "    tokens[\"labels\"] = labels\n",
+ "\n",
+ "    return tokens\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4783acc7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Map: 100%|█████████████████████████████████████████████████████████████████| 6554/6554 [00:14<00:00, 452.94 examples/s]\n",
+ "Map: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 102.88 examples/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "labeled_dataset = cons_dataset.map(generate_row_labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bb73c47b",
+ "metadata": {
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "---\n",
+ "## SciBERT Model Fine-Tuning\n",
+ "\n",
+ "We are now ready to fine-tune the SciBERT model on our dataset. This section is adapted from the 🤗 token classification example notebook: https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "08cb693f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "task = \"ner\" # Should be one of \"ner\", \"pos\" or \"chunk\"\n",
+ "#model_checkpoint = \"allenai/scibert_scivocab_uncased\"\n",
+ "batch_size = 16"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d7128ee4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "ff2b60fc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#!pip install transformers[torch]\n",
+ "model_name = model_checkpoint.split(\"/\")[-1]\n",
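+ "# the output directory is derived from the checkpoint name, e.g. trainedSB2-finetuned-ner\n",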
+ "args = TrainingArguments(\n",
+ "    f\"{model_name}-finetuned-{task}\",\n",
+ "    evaluation_strategy=\"epoch\",\n",
+ "    learning_rate=1e-5,\n",
+ "    per_device_train_batch_size=batch_size,\n",
+ "    per_device_eval_batch_size=batch_size,\n",
+ "    num_train_epochs=5,\n",
+ "    weight_decay=0.05,\n",
+ "    logging_steps=1\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "f1807139",
+ "metadata": {},
+ "outputs": [],
+ "source": [
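+ "# dynamically pads both the inputs and the label sequences to a common length within each batch\n",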
+ "data_collator = DataCollatorForTokenClassification(tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "560f0969",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#!pip install seqeval\n",
+ "#module = evaluate.load(\"lvwerra/element_count\")\n",
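+ "# seqeval scores complete entities (entity-level precision/recall/F1) rather than individual tokens\n",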
+ "metric = load_metric(\"seqeval\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "156e1cf4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_metrics(p):\n",
+ "    predictions, labels = p\n",
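+ "    # argmax over the label dimension turns per-token logits into predicted label ids\n",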
+ "    predictions = np.argmax(predictions, axis=2)\n",
+ "\n",
+ "    # Remove ignored index (special tokens)\n",
+ "    true_predictions = [\n",
+ "        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
+ "        for prediction, label in zip(predictions, labels)\n",
+ "    ]\n",
+ "    true_labels = [\n",
+ "        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
+ "        for prediction, label in zip(predictions, labels)\n",
+ "    ]\n",
+ "\n",
+ "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
+ "    return {\n",
+ "        \"precision\": results[\"overall_precision\"],\n",
+ "        \"recall\": results[\"overall_recall\"],\n",
+ "        \"f1\": results[\"overall_f1\"],\n",
+ "        \"accuracy\": results[\"overall_accuracy\"],\n",
+ "    }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "bd9c7f71",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ "    model,\n",
+ "    args,\n",
+ "    train_dataset=labeled_dataset[\"train\"],\n",
+ "    eval_dataset=labeled_dataset[\"test\"],\n",
+ "    data_collator=data_collator,\n",
+ "    tokenizer=tokenizer,\n",
+ "    compute_metrics=compute_metrics,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "e644de18",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'DRUG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 3},\n",
+ " 'GENE': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 13},\n",
+ " 'overall_precision': 1.0,\n",
+ " 'overall_recall': 1.0,\n",
+ " 'overall_f1': 1.0,\n",
+ " 'overall_accuracy': 1.0}"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predictions, labels, _ = trainer.predict(labeled_dataset[\"test\"])\n",
+ "predictions = np.argmax(predictions, axis=2)\n",
+ "\n",
+ "# Remove ignored index (special tokens)\n",
+ "true_predictions = [\n",
+ "    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
+ "    for prediction, label in zip(predictions, labels)\n",
+ "]\n",
+ "true_labels = [\n",
+ "    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
+ "    for prediction, label in zip(predictions, labels)\n",
+ "]\n",
+ "\n",
+ "results = metric.compute(predictions=true_predictions, references=true_labels)\n",
+ "results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5c7e4324",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## See Model Outputs\n",
+ "\n",
+ "We load our fine-tuned model into a `pipeline` object to run arbitrary input against it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "51e55f32",
+ "metadata": {},
+ "outputs": [],
+ "source": [
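+ "# wrap the fine-tuned model and tokenizer in a token-classification pipeline for quick inference\n",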
+ "effect_ner_model = pipeline(task=\"ner\", model=model, tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "387dd2af",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[{'entity': 'LABEL_0',\n",
+ "  'score': 0.9997309,\n",
+ "  'index': 1,\n",
+ "  'word': 'clinical',\n",
+ "  'start': 0,\n",
+ "  'end': 8},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997073,\n",
+ "  'index': 2,\n",
+ "  'word': 'studies',\n",
+ "  'start': 9,\n",
+ "  'end': 16},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9996835,\n",
+ "  'index': 3,\n",
+ "  'word': 'have',\n",
+ "  'start': 17,\n",
+ "  'end': 21},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.99972457,\n",
+ "  'index': 4,\n",
+ "  'word': 'indicated',\n",
+ "  'start': 22,\n",
+ "  'end': 31},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997483,\n",
+ "  'index': 5,\n",
+ "  'word': 'the',\n",
+ "  'start': 32,\n",
+ "  'end': 35},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997286,\n",
+ "  'index': 6,\n",
+ "  'word': 'following',\n",
+ "  'start': 36,\n",
+ "  'end': 45},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997805,\n",
+ "  'index': 7,\n",
+ "  'word': 'relative',\n",
+ "  'start': 46,\n",
+ "  'end': 54},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997863,\n",
+ "  'index': 8,\n",
+ "  'word': 'potency',\n",
+ "  'start': 55,\n",
+ "  'end': 62},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9997502,\n",
+ "  'index': 9,\n",
+ "  'word': 'differences',\n",
+ "  'start': 63,\n",
+ "  'end': 74},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.99968636,\n",
+ "  'index': 10,\n",
+ "  'word': ':',\n",
+ "  'start': 75,\n",
+ "  'end': 76},\n",
+ " {'entity': 'LABEL_1',\n",
+ "  'score': 0.9793291,\n",
+ "  'index': 11,\n",
+ "  'word': 'flu',\n",
+ "  'start': 77,\n",
+ "  'end': 80},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.98558635,\n",
+ "  'index': 12,\n",
+ "  'word': '##ticas',\n",
+ "  'start': 80,\n",
+ "  'end': 85},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.99234176,\n",
+ "  'index': 13,\n",
+ "  'word': '##one',\n",
+ "  'start': 85,\n",
+ "  'end': 88},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.6478919,\n",
+ "  'index': 14,\n",
+ "  'word': 'prop',\n",
+ "  'start': 89,\n",
+ "  'end': 93},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.8109712,\n",
+ "  'index': 15,\n",
+ "  'word': '##ionate',\n",
+ "  'start': 93,\n",
+ "  'end': 99},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9969921,\n",
+ "  'index': 16,\n",
+ "  'word': '>',\n",
+ "  'start': 100,\n",
+ "  'end': 101},\n",
+ " {'entity': 'LABEL_1',\n",
+ "  'score': 0.9772372,\n",
+ "  'index': 17,\n",
+ "  'word': 'bud',\n",
+ "  'start': 102,\n",
+ "  'end': 105},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9892029,\n",
+ "  'index': 18,\n",
+ "  'word': '##es',\n",
+ "  'start': 105,\n",
+ "  'end': 107},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9902013,\n",
+ "  'index': 19,\n",
+ "  'word': '##oni',\n",
+ "  'start': 107,\n",
+ "  'end': 110},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9839754,\n",
+ "  'index': 20,\n",
+ "  'word': '##de',\n",
+ "  'start': 110,\n",
+ "  'end': 112},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.99758625,\n",
+ "  'index': 21,\n",
+ "  'word': '=',\n",
+ "  'start': 113,\n",
+ "  'end': 114},\n",
+ " {'entity': 'LABEL_1',\n",
+ "  'score': 0.9797254,\n",
+ "  'index': 22,\n",
+ "  'word': 'bec',\n",
+ "  'start': 115,\n",
+ "  'end': 118},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.99061537,\n",
+ "  'index': 23,\n",
+ "  'word': '##lo',\n",
+ "  'start': 118,\n",
+ "  'end': 120},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9920591,\n",
+ "  'index': 24,\n",
+ "  'word': '##meth',\n",
+ "  'start': 120,\n",
+ "  'end': 124},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.98958415,\n",
+ "  'index': 25,\n",
+ "  'word': '##ason',\n",
+ "  'start': 124,\n",
+ "  'end': 128},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.98872375,\n",
+ "  'index': 26,\n",
+ "  'word': '##e',\n",
+ "  'start': 128,\n",
+ "  'end': 129},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.5297541,\n",
+ "  'index': 27,\n",
+ "  'word': 'dip',\n",
+ "  'start': 130,\n",
+ "  'end': 133},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.8482835,\n",
+ "  'index': 28,\n",
+ "  'word': '##rop',\n",
+ "  'start': 133,\n",
+ "  'end': 136},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.54456973,\n",
+ "  'index': 29,\n",
+ "  'word': '##ionate',\n",
+ "  'start': 136,\n",
+ "  'end': 142},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.99813783,\n",
+ "  'index': 30,\n",
+ "  'word': '>',\n",
+ "  'start': 143,\n",
+ "  'end': 144},\n",
+ " {'entity': 'LABEL_1',\n",
+ "  'score': 0.9497072,\n",
+ "  'index': 31,\n",
+ "  'word': 'tri',\n",
+ "  'start': 145,\n",
+ "  'end': 148},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.97295266,\n",
+ "  'index': 32,\n",
+ "  'word': '##am',\n",
+ "  'start': 148,\n",
+ "  'end': 150},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9831851,\n",
+ "  'index': 33,\n",
+ "  'word': '##cin',\n",
+ "  'start': 150,\n",
+ "  'end': 153},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9881147,\n",
+ "  'index': 34,\n",
+ "  'word': '##olone',\n",
+ "  'start': 153,\n",
+ "  'end': 158},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.4705981,\n",
+ "  'index': 35,\n",
+ "  'word': 'acet',\n",
+ "  'start': 159,\n",
+ "  'end': 163},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.8985247,\n",
+ "  'index': 36,\n",
+ "  'word': '##oni',\n",
+ "  'start': 163,\n",
+ "  'end': 166},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.81968755,\n",
+ "  'index': 37,\n",
+ "  'word': '##de',\n",
+ "  'start': 166,\n",
+ "  'end': 168},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.99776864,\n",
+ "  'index': 38,\n",
+ "  'word': '=',\n",
+ "  'start': 169,\n",
+ "  'end': 170},\n",
+ " {'entity': 'LABEL_1',\n",
+ "  'score': 0.9653367,\n",
+ "  'index': 39,\n",
+ "  'word': 'flu',\n",
+ "  'start': 171,\n",
+ "  'end': 174},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9850605,\n",
+ "  'index': 40,\n",
+ "  'word': '##nis',\n",
+ "  'start': 174,\n",
+ "  'end': 177},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9890426,\n",
+ "  'index': 41,\n",
+ "  'word': '##oli',\n",
+ "  'start': 177,\n",
+ "  'end': 180},\n",
+ " {'entity': 'LABEL_2',\n",
+ "  'score': 0.9816441,\n",
+ "  'index': 42,\n",
+ "  'word': '##de',\n",
+ "  'start': 180,\n",
+ "  'end': 182},\n",
+ " {'entity': 'LABEL_0',\n",
+ "  'score': 0.9975062,\n",
+ "  'index': 43,\n",
+ "  'word': '.',\n",
+ "  'start': 183,\n",
+ "  'end': 184}]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# something from our validation set\n",
+ "effect_ner_model(labeled_dataset[\"test\"][6][\"text\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b67731e",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "We try out a few example sentences about adverse effects, taken from the Wikipedia page below, and visualize the predictions with the displaCy library:\n",
+ "\n",
+ "https://en.wikipedia.org/wiki/Adverse_effect#Medications"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "872f67db",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def visualize_entities(sentence):\n",
+ "    tokens = effect_ner_model(sentence)\n",
+ "    entities = []\n",
+ "    # ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']\n",
+ "    for token in tokens:\n",
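+ "        # entity names come back as generic LABEL_<n> (see the output above), so the trailing digit indexes label_list\n",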
+ "        label = int(token[\"entity\"][-1])\n",
+ "        if label != 0:\n",
+ "            token[\"label\"] = label_list[label]\n",
+ "            entities.append(token)\n",
+ "\n",
+ "    params = [{\"text\": sentence,\n",
+ "               \"ents\": entities,\n",
+ "               \"title\": None}]\n",
+ "\n",
+ "    html = displacy.render(params, style=\"ent\", manual=True, options={\n",
+ "        \"colors\": {\n",
+ "            \"B-DRUG\": \"#f08080\",\n",
+ "            \"I-DRUG\": \"#f08080\",\n",
+ "            \"B-DISEASE\": \"#9bddff\",\n",
+ "            \"I-DISEASE\": \"#9bddff\",\n",
+ "            \"B-GENE\": \"#008080\",\n",
+ "            \"I-GENE\": \"#008080\",\n",
+ "        },\n",
+ "    })\n",
+ "    return html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b919e5e5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on local URL: http://127.0.0.1:7860\n",
+ "\n",
+ "Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "exampleList = [\n",
+ "    'Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.',\n",
+ "    'A randomized Phase III trial demonstrated noninferiority of APF530 500 mg SC ( granisetron 10 mg ) to intravenous palonosetron 0.25 mg in preventing CINV in patients receiving MEC or HEC in acute ( 0 - 24 hours ) and delayed ( 24 - 120 hours ) settings , with activity over 120 hours .',\n",
+ "    'What are the known interactions between Aspirin and the COX-1 enzyme?',\n",
+ "    'Can you explain the mechanism of action of Metformin and its effect on the AMPK pathway?',\n",
+ "    'Are there any genetic variations in the CYP2C9 gene that may influence the response to Warfarin therapy?',\n",
+ "    'I am curious about the role of Herceptin in targeting the HER2/neu protein in breast cancer treatment. How does it work?',\n",
+ "    'What are the common side effects associated with Lisinopril, an angiotensin-converting enzyme (ACE) inhibitor?',\n",
+ "    'Can you explain the significance of the BCR-ABL fusion protein in the context of Imatinib therapy for chronic myeloid leukemia (CML)?',\n",
+ "    'How does Ibuprofen affect the COX-2 enzyme compared to COX-1?',\n",
+ "    'Are there any recent studies exploring the use of Pembrolizumab as an immune checkpoint inhibitor targeting PD-1?',\n",
+ "    'I have heard about the SLC6A4 gene and its association with serotonin reuptake inhibitors (SSRIs) like Fluoxetine.',\n",
+ "    'Could you provide insights into the BRAF mutation and its relevance in response to Vemurafenib treatment in melanoma patients?'\n",
+ "]\n",
+ "\n",
+ "footer = \"\"\"\n",
+ "LLMGeneLinker uses a domain-specific transformer, SciBERT, fine-tuned on the AllenAI drug, BC5CDR disease, NCBI disease, DrugProt, and GeneTAG datasets. The resulting SciBERT model performs Named Entity Recognition to tag drugs, proteins, genes, and diseases in input text. The sentence embedding from SciBERT is then fed into BERT.\n",
+ "This was made during the <a target=\"_blank\" href=https://www.sginnovate.com/event/hackathon-large-language-models-bio>LLMs for Bio Hackathon</a> organised by 4Catalyzer and SGInnovate.\n",
+ "<br>\n",
+ "Made by Team GeneLink (<a target=\"_blank\" href=https://www.linkedin.com/in/ntkb/>Nicholas</a>, <a target=\"_blank\" href=https://www.linkedin.com/in/yewchong-sim/>Yew Chong</a>, <a target=\"_blank\" href=https://www.linkedin.com/in/lim-ting-wei-021383175/>Ting Wei</a>, <a target=\"_blank\" href=https://www.linkedin.com/in/brendan-lim-ciwen/>Brendan</a>)\n",
+ "<hr>\n",
+ "Note: performance is noticeably poorer on genes, acronyms, and receptors (named entities that may be targets for drugs or genes).\n",
+ "Original notebook adapted from <a target=\"_blank\" href=https://huggingface.co/jsylee/scibert_scivocab_uncased-finetuned-ner>jsylee/scibert_scivocab_uncased-finetuned-ner</a>\n",
+ "\"\"\"\n",
+ "\n",
+ "with gr.Blocks() as demo:\n",
+ "    gr.Markdown(\"## LLMGeneLinker (LGL)\")\n",
+ "    gr.Markdown(footer)\n",
+ "\n",
+ "    txt = gr.Textbox(label=\"Input\", lines=2)\n",
+ "    txt_3 = gr.HTML(label=\"Output\")\n",
+ "    btn = gr.Button(value=\"Submit\")\n",
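+ "    # clicking Submit runs visualize_entities on the textbox text and renders the returned HTML\n",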
+ "    btn.click(visualize_entities, inputs=txt, outputs=txt_3)\n",
+ "\n",
+ "    gr.Markdown(\"## Text Examples\")\n",
+ "    gr.Examples(\n",
+ "        [[x] for x in exampleList],\n",
+ "        txt,\n",
+ "        txt_3,\n",
+ "        visualize_entities,\n",
+ "        cache_examples=False,\n",
+ "        run_on_click=True\n",
+ "    )\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    demo.launch()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }