Spaces:
Runtime error
Runtime error
Ryan Kim
commited on
Commit
Β·
548d450
1
Parent(s):
fda5a48
final push
Browse files- README.md +35 -1
- src/main.py +1 -6
- src/val.ipynb +1 -0
README.md
CHANGED
@@ -13,6 +13,8 @@ pinned: false
|
|
13 |
|
14 |
## Milestone 4
|
15 |
|
|
|
|
|
16 |
### **Code Breakdown**
|
17 |
|
18 |
The USPTO application is divided into several directories. Overall, the important files are present in the application as such:
|
@@ -24,9 +26,10 @@ The USPTO application is divided into several directories. Overall, the importan
|
|
24 |
- src/
|
25 |
- main.py
|
26 |
- train.ipynb
|
|
|
27 |
````
|
28 |
|
29 |
-
Both `train.json` and `val.json` contain the original USPTO data, sized down to contain only the relevant data from each recorded patent and split between training and validation data. The validation data `val.json` is used in the online USPTO application as a set of pre-set patents that a user can select when using the USPTO patent prediction function.
|
30 |
|
31 |
The primary code back-end is stored in `main.py`, which runs the application on the HuggingFace space UI. The application uses **Streamlit** to render UI elements on the screen. All models run off of Transformers and Tokenizers from **HuggingFace**.
|
32 |
|
@@ -348,6 +351,37 @@ print("=== TRAINING CLAIMS ===")
|
|
348 |
Train(train_claims_loader,upsto_claims_model_path, num_train_epochs=10)
|
349 |
````
|
350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
---
|
352 |
|
353 |
## Milestone 3
|
|
|
13 |
|
14 |
## Milestone 4
|
15 |
|
16 |
+
[![Project Video](https://img.youtube.com/vi/csSGBwIE7nk/0.jpg)](https://youtu.be/csSGBwIE7nk "Project Video")
|
17 |
+
|
18 |
### **Code Breakdown**
|
19 |
|
20 |
The USPTO application is divided into several directories. Overall, the important files are present in the application as such:
|
|
|
26 |
- src/
|
27 |
- main.py
|
28 |
- train.ipynb
|
29 |
+
- val.ipynb
|
30 |
````
|
31 |
|
32 |
+
Both `train.json` and `val.json` contain the original USPTO data, sized down to contain only the relevant data from each recorded patent and split between training and validation data. The validation data `val.json` is used in the online USPTO application as a set of pre-set patents that a user can select when using the USPTO patent prediction function. That, and the `val.ipynb` file was used to validate the model's accuracy.
|
33 |
|
34 |
The primary code back-end is stored in `main.py`, which runs the application on the HuggingFace space UI. The application uses **Streamlit** to render UI elements on the screen. All models run off of Transformers and Tokenizers from **HuggingFace**.
|
35 |
|
|
|
351 |
Train(train_claims_loader,upsto_claims_model_path, num_train_epochs=10)
|
352 |
````
|
353 |
|
354 |
+
### Evaluating the Models
|
355 |
+
|
356 |
+
#### Sentiment Analysis
|
357 |
+
|
358 |
+
There isn't an effective way to validate the sentiment analysis models, as they are publicly available models and it is unknown what data they were explicitly trained on. Therefore, evaluation will rely on anecdotal testing.
|
359 |
+
|
360 |
+
The sentiment models that appear to work the best are the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) and - [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) models, with few caveats. These two models generally perform very well at detecting sentiment in mid to long expressions. However the [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) model tends to suffer when expressions are shorter and less complex in lexicon. Even the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) model suffers from time to time if not enough context has been provided.
|
361 |
+
|
362 |
+
The model that performed the worst is the [finiteautomata/beto-sentiment-analysis](https://huggingface.co/finiteautomata/beto-sentiment-analysis). This model seems to have the worst time trying to interpret meaning from sentences, even with strongly worded language such as "hate". For example, the expression _"I hate you" returns a **NEUTRAL** response with 99.6% confidence, which differs from the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) and - [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) models (**NEGATIVE**: ~96.5% - ~99.9% accuracy respectively). It appears that the [finiteautomata/beto-sentiment-analysis](https://huggingface.co/finiteautomata/beto-sentiment-analysis) gets confused when not enough context is provided. The expression "I hate you because you hurt my family" manages to return a **NEGATIVE** label, but with a mere 87.7% confidence.
|
363 |
+
|
364 |
+
The unique model is the [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) model, which instead gives 6 general emotions as opposed to a binary **NEGATIVE** or **POSITIVE** rating:
|
365 |
+
|
366 |
+
* Sadness
|
367 |
+
* Joy
|
368 |
+
* Love
|
369 |
+
* Anger
|
370 |
+
* Fear
|
371 |
+
* Surprise
|
372 |
+
|
373 |
+
This one offers more nuance to each evaluation, but while the general distinction between "positive" and "negative" emotions is often accurate, the specific emotion applied to the statement is sometimes confusing. For example, the expression "I love you because you saved my family" is labeled as **JOY** instead of **LOVE**.
|
374 |
+
|
375 |
+
Overall, all models are able to perform to some level of success, but the context of each input needs to be made clear in the wording for the models to produce accurate responses. These models seem to suffer when not enough context is provided, sometimes blatantly giving incorrect responses to situations that are too simple.
|
376 |
+
|
377 |
+
#### Patent Acceptance Prediction
|
378 |
+
|
379 |
+
With access to labeled validation data, the fine-tuned models can be ranked in accuracy. Sample code that does so is provided in `src/val.ipynb` and evaluates 1000 random data samples out of 4000+ samples inside of `src/val.ipynb`.
|
380 |
+
|
381 |
+
Overall, both the fine-tuned Abstract model and the Claims model seem to perform very similarly, producing accuracy rates of 72.89% and 72.8% accuracy respectively out of 1000 random samples. The aggregated softmax labeling, assuming equal weighting between the two models, produces 76.2% accuracy. Depending on who you ask, this performance can be discouraging or encouraging.
|
382 |
+
|
383 |
+
The 76% accuracy rating being higher than the individual fine-tuned models is interesting, as it implies that the accuracy between the two models is not consistent; there are times when one of the models might outperform the other in certain situations. If the accuracy remained the same across all 1000 random samples, then the aggregated prediction accuracy rate would be within the same range as the individual fine-tuned models.
|
384 |
+
|
385 |
---
|
386 |
|
387 |
## Milestone 3
|
src/main.py
CHANGED
@@ -33,8 +33,6 @@ class ModelImplementation(object):
|
|
33 |
self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
|
34 |
self.classifier = pipeline_func(model=self.model, tokenizer=self.tokenizer, padding=True, truncation=True, **classifier_args)
|
35 |
self.parser = parser_func
|
36 |
-
|
37 |
-
self.history = []
|
38 |
|
39 |
def predict(self, val):
|
40 |
result = self.classifier(val)
|
@@ -97,10 +95,7 @@ def emotion_model_change():
|
|
97 |
classifier_args={ "task" : "sentiment-analysis" },
|
98 |
placeholders=["@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."]
|
99 |
)
|
100 |
-
|
101 |
-
if "page" not in st.session_state:
|
102 |
-
st.session_state.page = "home"
|
103 |
-
|
104 |
if "emotion_model_name" not in st.session_state:
|
105 |
st.session_state.emotion_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
106 |
emotion_model_change()
|
|
|
33 |
self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
|
34 |
self.classifier = pipeline_func(model=self.model, tokenizer=self.tokenizer, padding=True, truncation=True, **classifier_args)
|
35 |
self.parser = parser_func
|
|
|
|
|
36 |
|
37 |
def predict(self, val):
|
38 |
result = self.classifier(val)
|
|
|
95 |
classifier_args={ "task" : "sentiment-analysis" },
|
96 |
placeholders=["@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."]
|
97 |
)
|
98 |
+
|
|
|
|
|
|
|
99 |
if "emotion_model_name" not in st.session_state:
|
100 |
st.session_state.emotion_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
101 |
emotion_model_change()
|
src/val.ipynb
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPbIO5QK/V8keB7h6h+8Ju2"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":22,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ePuwhQ7QyzUW","executionInfo":{"status":"ok","timestamp":1682571700367,"user_tz":240,"elapsed":29378,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"9c939d4a-7622-4c48-ba58-b83162400692"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: datasets in /usr/local/lib/python3.9/dist-packages (2.11.0)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (6.0)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (4.65.0)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets) (1.5.3)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets) (3.2.0)\n","Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.18.0)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets) (1.22.4)\n","Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.14.1)\n","Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets) (0.70.14)\n","Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.3.6)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (9.0.0)\n","Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (2023.4.0)\n","Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (2.27.1)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets) (3.8.4)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (23.1.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.3)\n","Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (2.0.12)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (6.0.4)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (4.0.2)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.1)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.9.2)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (3.12.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (3.4)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2022.7.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: streamlit in /usr/local/lib/python3.9/dist-packages (1.21.0)\n","Requirement already satisfied: packaging>=14.1 in /usr/local/lib/python3.9/dist-packages (from streamlit) (23.1)\n","Requirement already satisfied: toml in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.10.2)\n","Requirement already satisfied: tzlocal>=1.1 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.3)\n","Requirement already satisfied: protobuf<4,>=3.12 in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.20.3)\n","Requirement already satisfied: importlib-metadata>=1.4 in /usr/local/lib/python3.9/dist-packages (from streamlit) (6.6.0)\n","Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (13.3.4)\n","Requirement already satisfied: python-dateutil in /usr/local/lib/python3.9/dist-packages (from streamlit) (2.8.2)\n","Requirement already satisfied: pympler>=0.9 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.0.1)\n","Requirement already satisfied: pandas<2,>=0.25 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.5.3)\n","Requirement already satisfied: typing-extensions>=3.10.0.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.5.0)\n","Requirement already satisfied: validators>=0.2 in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.20.0)\n","Requirement already satisfied: blinker>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.6.2)\n","Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (8.4.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.22.4)\n","Requirement already satisfied: tornado>=6.0.3 in /usr/local/lib/python3.9/dist-packages (from streamlit) (6.2)\n","Requirement already satisfied: watchdog in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.0.0)\n","Requirement already satisfied: gitpython!=3.1.19 in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.1.31)\n","Requirement already satisfied: pydeck>=0.1.dev5 in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.8.1b0)\n","Requirement already satisfied: requests>=2.4 in /usr/local/lib/python3.9/dist-packages (from streamlit) (2.27.1)\n","Requirement already satisfied: cachetools>=4.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (5.3.0)\n","Requirement already satisfied: altair<5,>=3.2.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.2.2)\n","Requirement already satisfied: pyarrow>=4.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (9.0.0)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (8.1.3)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (3.1.2)\n","Requirement already satisfied: toolz in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (0.12.0)\n","Requirement already satisfied: entrypoints in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (0.4)\n","Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (4.3.3)\n","Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.9/dist-packages (from gitpython!=3.1.19->streamlit) (4.0.10)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=1.4->streamlit) (3.15.0)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas<2,>=0.25->streamlit) (2022.7.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil->streamlit) (1.16.0)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (2.0.12)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (2022.12.7)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (1.26.15)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (3.4)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.9/dist-packages (from rich>=10.11.0->streamlit) (2.14.0)\n","Requirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /usr/local/lib/python3.9/dist-packages (from rich>=10.11.0->streamlit) (2.2.0)\n","Requirement already satisfied: pytz-deprecation-shim in /usr/local/lib/python3.9/dist-packages (from tzlocal>=1.1->streamlit) (0.1.0.post0)\n","Requirement already satisfied: decorator>=3.4.0 in /usr/local/lib/python3.9/dist-packages (from validators>=0.2->streamlit) (4.4.2)\n","Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.19->streamlit) (5.0.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->altair<5,>=3.2.0->streamlit) (2.1.2)\n","Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.9/dist-packages (from jsonschema>=3.0->altair<5,>=3.2.0->streamlit) (23.1.0)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/dist-packages (from jsonschema>=3.0->altair<5,>=3.2.0->streamlit) (0.19.3)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.2.0->rich>=10.11.0->streamlit) (0.1.2)\n","Requirement already satisfied: tzdata in /usr/local/lib/python3.9/dist-packages (from pytz-deprecation-shim->tzlocal>=1.1->streamlit) (2023.3)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: transformers in /usr/local/lib/python3.9/dist-packages (4.28.1)\n","Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.12.0)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.14.1)\n","Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (2023.4.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (4.65.0)\n"]}],"source":["!pip install datasets\n","!pip install streamlit\n","!pip install transformers\n","!pip install tqdm"]},{"cell_type":"code","source":["from datasets import load_dataset\n","import pandas as pd\n","import numpy as np\n","import os\n","import json\n","import torch\n","import sys\n","from tqdm import tqdm\n","\n","import streamlit as st\n","from transformers import TextClassificationPipeline, pipeline\n","from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertTokenizerFast, DistilBertForSequenceClassification"],"metadata":{"id":"xqhKMsNVzBtY","executionInfo":{"status":"ok","timestamp":1682571793784,"user_tz":240,"elapsed":3,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}}},"execution_count":27,"outputs":[]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/gdrive')"],"metadata":{"id":"4E_xZUUwzGJm","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1682570070672,"user_tz":240,"elapsed":23530,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"a6fbb01a-caeb-4dc5-bef1-837c5dce202f"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}]},{"cell_type":"code","source":["abstract_model = TextClassificationPipeline(\n"," model = DistilBertForSequenceClassification.from_pretrained('rk2546/uspto-patents-abstracts'),\n"," tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased'),\n"," padding = True, \n"," truncation = True,\n"," return_all_scores = True\n",")\n","\n","claim_model = TextClassificationPipeline(\n"," model = DistilBertForSequenceClassification.from_pretrained('rk2546/uspto-patents-claims'),\n"," tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased'),\n"," padding = True, \n"," truncation = True,\n"," return_all_scores = True\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Mj3hQGRU90bA","executionInfo":{"status":"ok","timestamp":1682573368942,"user_tz":240,"elapsed":7417,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"05cc8f93-1c72-4880-ae76-8d132d500c5f"},"execution_count":39,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.9/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n"," warnings.warn(\n"]}]},{"cell_type":"code","source":["path_to_valData = \"./gdrive/MyDrive/AI [Spring 2023]/cs-gy-6613-project-rk2546/val.json\"\n","f = open(path_to_valData)\n","valData = json.load(f)\n","f.close()"],"metadata":{"id":"0oimA5tO9c1G","executionInfo":{"status":"ok","timestamp":1682570188049,"user_tz":240,"elapsed":1507,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# We track the successes of abstracts, claims, and both combined\n","abstract_successes = 0\n","claim_successes = 0\n","aggregate_successes = 0\n","total_num = len(valData['labels'])\n","\n","# By default, we weigh the claims more highly than abstracts\n","claim_weight = 0.5\n","abstract_weight = 0.5\n","\n","# To randomize the data, we generate random indices \n","index_perms = np.random.permutation(total_num)\n","labels = []\n","abstracts = []\n","claims = []\n","# We generate up to 500 samples to validate against\n","new_total_num = min(1000,len(index_perms))\n","for i in range(new_total_num):\n"," labels.append(valData['labels'][index_perms[i]])\n"," abstracts.append(valData['abstracts'][index_perms[i]])\n"," claims.append(valData['claims'][index_perms[i]])\n","\n","# Now we validate\n","for i in tqdm(range(new_total_num)):\n"," label = labels[i]\n"," abstract = abstracts[i]\n"," claim = claims[i]\n","\n"," abstract_response = abstract_model(abstract)[0]\n"," claim_response = claim_model(claim)[0]\n"," aggregate_response = [\n"," {'label':'REJECTED','score':abstract_response[0]['score']*abstract_weight + claim_response[0]['score']*claim_weight},\n"," {'label':'ACCEPTED','score':abstract_response[1]['score']*abstract_weight + claim_response[1]['score']*claim_weight}\n"," ]\n","\n"," abstract_sorted = sorted(abstract_response, key=lambda d: d['score'], reverse=True) \n"," claim_sorted = sorted(claim_response, key=lambda d: d['score'], reverse=True)\n"," aggregate_sorted = sorted(aggregate_response, key=lambda d: d['score'], reverse=True) \n","\n"," if abstract_sorted[0]['label'] == 'LABEL_1' and label == 1:\n"," abstract_successes += 1\n"," elif abstract_sorted[0]['label'] == 'LABEL_0' and label == 0:\n"," abstract_successes += 1\n"," \n"," if claim_sorted[0]['label'] == 'LABEL_1' and label == 1:\n"," claim_successes += 1\n"," elif claim_sorted[0]['label'] == 'LABEL_0' and label == 0:\n"," claim_successes += 1\n"," \n"," if aggregate_sorted[0]['label'] == 'ACCEPTED' and label == 1:\n"," aggregate_successes += 1\n"," elif aggregate_sorted[0]['label'] == 'REJECTED' and label == 0:\n"," aggregate_successes += 1\n","\n"," # At 10% intervals, we print the current results\n"," if i > 0 and i % (new_total_num * 0.1) == 0:\n"," print(f\"\\nAbs: {abstract_successes}/{i} | Cl: {claim_successes}/{i} | Agg: {aggregate_successes}/{i}\")\n","\n","# Calculate final accuracy\n","abstract_accuracy = abstract_successes / new_total_num\n","claim_accuracy = claim_successes / new_total_num\n","aggregate_accuracy = aggregate_successes / new_total_num\n","\n","# Display accuracy\n","print(\"\\n\")\n","print(f\"Abstract Model Accuracy: {abstract_accuracy * 100}%\")\n","print(f\"Claim Model Accuracy: {claim_accuracy * 100}%\")\n","print(f\"Aggregated Model Accuracy: {aggregate_accuracy * 100}%\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"FLE-9qlw9qW7","executionInfo":{"status":"ok","timestamp":1682577092672,"user_tz":240,"elapsed":1356393,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"fe0fb5ed-b075-4e4d-c616-6dbac5148a75"},"execution_count":48,"outputs":[{"output_type":"stream","name":"stderr","text":[" 10%|β | 101/1000 [02:25<22:03, 1.47s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 70/100 | Cl: 73/100 | Agg: 73/100\n"]},{"output_type":"stream","name":"stderr","text":[" 20%|ββ | 201/1000 [04:38<21:25, 1.61s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 148/200 | Cl: 155/200 | Agg: 155/200\n"]},{"output_type":"stream","name":"stderr","text":[" 30%|βββ | 301/1000 [06:53<13:59, 1.20s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 220/300 | Cl: 224/300 | Agg: 234/300\n"]},{"output_type":"stream","name":"stderr","text":[" 40%|ββββ | 401/1000 [09:08<11:16, 1.13s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 295/400 | Cl: 293/400 | Agg: 308/400\n"]},{"output_type":"stream","name":"stderr","text":[" 50%|βββββ | 501/1000 [11:24<10:34, 1.27s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 362/500 | Cl: 365/500 | Agg: 383/500\n"]},{"output_type":"stream","name":"stderr","text":[" 60%|ββββββ | 601/1000 [13:37<10:44, 1.61s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 443/600 | Cl: 440/600 | Agg: 462/600\n"]},{"output_type":"stream","name":"stderr","text":[" 70%|βββββββ | 701/1000 [15:54<06:52, 1.38s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 523/700 | Cl: 517/700 | Agg: 546/700\n"]},{"output_type":"stream","name":"stderr","text":[" 80%|ββββββββ | 801/1000 [18:07<03:42, 1.12s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 601/800 | Cl: 591/800 | Agg: 626/800\n"]},{"output_type":"stream","name":"stderr","text":[" 90%|βββββββββ | 901/1000 [20:24<01:56, 1.18s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 670/900 | Cl: 666/900 | Agg: 703/900\n"]},{"output_type":"stream","name":"stderr","text":["100%|ββββββββββ| 1000/1000 [22:36<00:00, 1.36s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","\n","Abstract Model Accuracy: 72.89999999999999%\n","Claim Model Accuracy: 72.8%\n","Aggregated Model Accuracy: 76.2%\n"]},{"output_type":"stream","name":"stderr","text":["\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"Enwp7rw___5t"},"execution_count":null,"outputs":[]}]}
|