Spaces:
Runtime error
Commit: add
- .gitattributes +1 -0
- __init__.py +0 -0
- app.py +21 -0
- configs/best_model.pt +3 -0
- configs/config.yaml +30 -0
- datasets/mix_dataset_2022_08_30.csv +3 -0
- example.ipynb +122 -0
- extraction.py +22 -0
- src/clustering.py +48 -0
- src/datas.py +133 -0
- src/model.py +225 -0
- src/run.py +279 -0
- src/utils.py +299 -0
- train.py +12 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+datasets/mix_dataset_2022_08_30.csv filter=lfs diff=lfs merge=lfs -text
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,21 @@
import gradio as gr
from extraction import paragraph_extract

def predict(paragraphs, positions):
    paragraphs = [paragraphs]
    positions = [positions]
    return extractor(paragraphs, positions)[0]

extractor = paragraph_extract().extract

example_paragraph = 'The W/Zr/HfO2/TiN structure was fabricated following the scheme shown in the inset of Fig. 1(a). A 5-nm-thick HfO2 layer was deposited on a TiN substrate by an atomic layer deposition system. After HfO2 film deposition, thermal annealing was performed under NH3 at 700 °C in order to achieve optimum concentration of oxygen vacancies [10]. Then, the 3-nm-thick Zr top electrode and a 50-nm-thick W capping layer were deposited by RF magnetron sputtering system. The size of the upper electrode was 10×10 μm2. The electrical measurements were performed by an Agilent B1500A semiconductor device analyzer, equipped with two pulse generator modules WGFMU (Waveform Generator and Fast Measurement Unit). The coaxial cables with a 50-Ω resistance and less than 10 cm in length were used to reduce the parasitic effects.'
example_position = 4

demo = gr.Interface(fn=predict, inputs=[gr.inputs.Textbox(lines=3, label="Paragraphs", placeholder='Text Here...'),
                                        gr.inputs.Number(label="Positions")],
                    outputs="text",
                    title="ReRAM Paragraph Classification", allow_flagging=False,
                    examples=[[example_paragraph, example_position]],
                    )

demo.launch(share=True)
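The notebook later in this commit (example.ipynb) logs warnings that `gr.inputs.*` and boolean `allow_flagging` are deprecated. A minimal sketch of the same interface against the post-3.0 Gradio component API (not part of this commit) would look like:

import gradio as gr
from extraction import paragraph_extract

extractor = paragraph_extract().extract

def predict(paragraph, position):
    # wrap the single inputs into the batched form extract() expects
    return extractor([paragraph], [position])[0]

demo = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(lines=3, label="Paragraphs", placeholder='Text Here...'),
            gr.Number(label="Positions")],
    outputs="text",
    title="ReRAM Paragraph Classification",
    allow_flagging="never",  # newer Gradio expects a string here, not a boolean
)
demo.launch(share=True)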
configs/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5dfa79fd27e2577a6dfedb39d439cee4e4d823789fe844127a930f179b04592d
size 442225641
configs/config.yaml
ADDED
@@ -0,0 +1,30 @@
train:
  epochs: 20
  batch_size: 16
  lr: 1.0E-5
  dropout: 0.1
  data_cut: null
  early_stop_count: 5

wandb:
  wandb_log: True
  wandb_project: 'paragraph_classification'
  wandb_group: 'model_test'
  wandb_memo: 'scibert'
  wandb_name: 'scibert'

model:
  model_name: 'allenai/scibert_scivocab_cased'
  data_file: './datasets/mix_dataset_2022_08_30.csv'
  max_length: 512
  random_state: 1000
  task_type: 'scalar'
  freeze_layers: null
  num_classifier: 1
  num_pos_emb_layer: 1
  sentence_piece: False
  bertsum: False

extract:
  selected_model: 'configs/best_model.pt'
  batch_size: 16
datasets/mix_dataset_2022_08_30.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd4fe999f4289ed8840dd9600958da1ad8c11caa29fca8eb6ce89a6e42df391a
size 13741487
example.ipynb
ADDED
@@ -0,0 +1,122 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rnwnsgud1234/0.Files/anaconda3/envs/nlp/lib/python3.9/site-packages/gradio/inputs.py:26: UserWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
      "  warnings.warn(\n",
      "/home/rnwnsgud1234/0.Files/anaconda3/envs/nlp/lib/python3.9/site-packages/gradio/deprecation.py:40: UserWarning: `optional` parameter is deprecated, and it has no effect\n",
      "  warnings.warn(value)\n",
      "/home/rnwnsgud1234/0.Files/anaconda3/envs/nlp/lib/python3.9/site-packages/gradio/deprecation.py:40: UserWarning: `numeric` parameter is deprecated, and it has no effect\n",
      "  warnings.warn(value)\n",
      "/home/rnwnsgud1234/0.Files/anaconda3/envs/nlp/lib/python3.9/site-packages/gradio/inputs.py:58: UserWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
      "  warnings.warn(\n",
      "/home/rnwnsgud1234/0.Files/anaconda3/envs/nlp/lib/python3.9/site-packages/gradio/interface.py:359: UserWarning: The `allow_flagging` parameter in `Interface` nowtakes a string value ('auto', 'manual', or 'never'), not a boolean. Setting parameter to: 'never'.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7862\n",
      "Running on public URL: https://bd6b9acba15cf888.gradio.app\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://bd6b9acba15cf888.gradio.app\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00, 1.37s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00, 67.94it/s]\n"
     ]
    }
   ],
   "source": [
    "import gradio as gr\n",
    "from extraction import paragraph_extract\n",
    "\n",
    "def predict(paragraphs, positions):\n",
    "    paragraphs = [paragraphs]\n",
    "    positions = [positions]\n",
    "    return extractor(paragraphs, positions)[0]\n",
    "\n",
    "extractor = paragraph_extract().extract\n",
    "\n",
    "example_paragraph = 'The W/Zr/HfO2/TiN structure was fabricated following the scheme shown in the inset of Fig. 1(a). A 5-nm-thick HfO2 layer was deposited on a TiN substrate by an atomic layer deposition system. After HfO2 film deposition, thermal annealing was performed under NH3 at 700 °C in order to achieve optimum concentration of oxygen vacancies [10]. Then, the 3-nm-thick Zr top electrode and a 50-nm-thick W capping layer were deposited by RF magnetron sputtering system. The size of the upper electrode was 10×10 μm2. The electrical measurements were performed by an Agilent B1500A semiconductor device analyzer, equipped with two pulse generator modules WGFMU (Waveform Generator and Fast Measurement Unit). The coaxial cables with a 50-Ω resistance and less than 10 cm in length were used to reduce the parasitic effects.'\n",
    "example_position = 4\n",
    "\n",
    "demo = gr.Interface(fn=predict, inputs=[gr.inputs.Textbox(lines=3, label=\"Paragraphs\", placeholder='Text Here...'),\n",
    "                                        gr.inputs.Number(label=\"Positions\")],\n",
    "                    outputs=\"text\",\n",
    "                    title=\"ReRAM Paragraph Classification\", allow_flagging=False,\n",
    "                    examples=[[example_paragraph, example_position]],\n",
    "                    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.15 ('nlp')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a0944428a9b48e048108e25849b0259875c53ceed2bb9cda9ef2b8036da8c8e0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
extraction.py
ADDED
@@ -0,0 +1,22 @@
try:
    from .src.run import NLP_classification
except ImportError:
    from src.run import NLP_classification
import yaml
import os


class paragraph_extract:
    def __init__(self):
        config_file = 'configs/config.yaml'
        config_file = os.path.join(os.path.dirname(__file__), config_file)
        self.config = yaml.load(open(config_file), Loader=yaml.FullLoader)
        self.config['extract']['selected_model'] = os.path.join(os.path.dirname(__file__), self.config['extract']['selected_model'])

        self.runner = NLP_classification(**self.config['model'])

    def extract(self, paragraphs, positions):
        labels = self.runner.label_extraction(paragraphs, positions, **self.config['extract'])
        return labels
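`paragraph_extract.extract` accepts either a single paragraph and position or parallel lists of them. A usage sketch (the paragraph texts and the printed labels are illustrative, not from this commit):

from extraction import paragraph_extract

extractor = paragraph_extract()
paragraphs = ['A 5-nm-thick HfO2 layer was deposited on a TiN substrate ...',
              'In summary, we demonstrated reliable resistive switching ...']
positions = [4, 9]  # paragraph index of each passage within its source paper
labels = extractor.extract(paragraphs, positions)
print(labels)  # one section label per paragraph, e.g. ['Methods', 'Summary']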
src/clustering.py
ADDED
@@ -0,0 +1,48 @@
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import umap

def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Dimension reduction using UMAP.
    """
    reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=umap_dim, min_dist=min_dist, metric='cosine', random_state=500)
    embeddings = reducer.fit_transform(target_embeddings)
    return embeddings


def clustering_plot(target_label, embeddings, label_trues, model_preds=None, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Plot the clustering results.
    """
    label_dict = {0:'Abstract', 1:'Introduction', 2:'Main', 3:'Methods', 4:'Summary', 5:'Captions'}

    target_index = np.where(label_trues == target_label)[0]

    trues = label_trues[target_index]
    embeddings = embeddings[target_index]

    embeddings = dim_reduction(embeddings, umap_dim=umap_dim, n_neighbors=n_neighbors, min_dist=min_dist)

    df = pd.DataFrame(embeddings, columns=['x', 'y'])
    df['true'] = trues
    df['true'] = df['true'].map(label_dict)
    if model_preds is not None:
        df['pred'] = model_preds[target_index]
        df['pred'] = df['pred'].map(label_dict)

    sns.scatterplot(x='x', y='y', hue='true', data=df, palette='Set2')
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()

    if model_preds is not None:
        sns.scatterplot(x='x', y='y', hue='pred', data=df, palette='Set2')
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
        plt.show()

    return df
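A minimal sketch of how these plotting helpers are driven; the synthetic embeddings below stand in for the hidden states returned by `NLP_classification.get_embedding`, and the shapes are assumptions:

import numpy as np
from src.clustering import clustering_plot

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(300, 768))     # stand-in for model hidden states
label_trues = rng.integers(0, 6, size=300)   # true section labels, 0-5

# UMAP-project and scatter-plot only the points whose true label is 2 ('Main')
df = clustering_plot(target_label=2, embeddings=embeddings, label_trues=label_trues)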
src/datas.py
ADDED
@@ -0,0 +1,133 @@
import numpy as np

import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
import torch
import os
import ast
from sklearn.utils import shuffle
import random
from spacy.lang.en import English
from .utils import sentencepiece

def make_dataset(csv_file, tokenizer, max_length=512, padding=None, random_state=1000, data_cut=None, sentence_piece=True):
    ''' data load '''
    ''' phase-1 + phase-2 data '''
    #data = csv_file
    #total_data = pd.read_csv(data)

    ''' provided dataset '''
    total_data = pd.read_csv(csv_file)
    total_data.columns = ['paragraph', 'category', 'position', 'portion']
    label_dict = {'Abstract':0, 'Introduction':1, 'Main':2, 'Methods':3, 'Summary':4, 'Captions':5}
    total_data['label'] = total_data.category.replace(label_dict)

    if data_cut is not None:
        total_data = total_data.iloc[:data_cut, :]

    total_text = total_data['paragraph'].to_list()
    total_label = total_data['label'].to_list()
    total_position = total_data['position'].to_list()
    total_portion = total_data['portion'].to_list()

    ''' guard against type errors '''
    if type(total_label[0]) == str:
        total_label = [ast.literal_eval(l) for l in total_label]

    if type(total_label[0]) == int:
        total_label = np.eye(6)[total_label].tolist()

    train_text, val_text, train_labels, val_labels, train_position, val_position, train_portion, val_portion = train_test_split(total_text, total_label, total_position, total_portion, test_size=0.2, random_state=random_state, stratify=total_label)

    ''' tokenize the data '''
    if not sentence_piece:
        train_encodings = tokenizer.batch_encode_plus(train_text, truncation=True, return_token_type_ids=True, max_length=max_length, add_special_tokens=True, return_attention_mask=True, padding='max_length')
        val_encodings = tokenizer.batch_encode_plus(val_text, truncation=True, return_token_type_ids=True, max_length=max_length, add_special_tokens=True, return_attention_mask=True, padding='max_length')
    else:
        nlp = English()
        nlp.add_pipe('sentencizer')
        train_encodings = sentencepiece(train_text, nlp, tokenizer, max_length=max_length)
        val_encodings = sentencepiece(val_text, nlp, tokenizer, max_length=max_length)

    ''' convert tokens to tensors '''
    train_encodings = {key: torch.tensor(val) for key, val in train_encodings.items()}
    val_encodings = {key: torch.tensor(val) for key, val in val_encodings.items()}

    ''' convert labels to tensors '''
    train_labels_ = {}
    train_labels_['label_onehot'] = torch.tensor(train_labels, dtype=torch.float)
    train_labels_['label'] = torch.tensor([t.index(1) for t in train_labels], dtype=torch.int)
    train_labels = train_labels_

    val_labels_ = {}
    val_labels_['label_onehot'] = torch.tensor(val_labels, dtype=torch.float)
    val_labels_['label'] = torch.tensor([t.index(1) for t in val_labels], dtype=torch.long)
    val_labels = val_labels_

    ''' convert positions to tensors '''
    train_positions_ = {}
    train_positions_['position'] = torch.tensor(train_position, dtype=torch.float)
    train_positions_['portion'] = torch.tensor(train_portion, dtype=torch.float)
    train_positions = train_positions_

    val_positions_ = {}
    val_positions_['position'] = torch.tensor(val_position, dtype=torch.float)
    val_positions_['portion'] = torch.tensor(val_portion, dtype=torch.float)
    val_positions = val_positions_

    ''' define the dataset class '''
    class CustomDataset(torch.utils.data.Dataset):
        def __init__(self, encodings, labels, texts, positions):
            self.encodings = encodings
            self.labels = labels
            self.texts = texts
            self.positions = positions

        def __getitem__(self, idx):
            item = {key: val[idx] for key, val in self.encodings.items()}
            item['text'] = self.texts[idx]
            # scalar version
            item['label'] = self.labels['label'][idx]
            # one-hot version
            item['label_onehot'] = self.labels['label_onehot'][idx]
            # position
            item['position'] = self.positions['position'][idx]
            # portion
            item['portion'] = self.positions['portion'][idx]
            return item

        def __len__(self):
            return len(self.labels['label_onehot'])

    ''' wrap the data in the format needed for training '''
    train_dataset = CustomDataset(train_encodings, train_labels, train_text, train_positions)
    val_dataset = CustomDataset(val_encodings, val_labels, val_text, val_positions)

    return train_dataset, val_dataset


def make_extract_dataset(paragraphs, positions, tokenizer, max_length):
    encodings = tokenizer.batch_encode_plus(paragraphs, truncation=True, return_token_type_ids=True, max_length=max_length, add_special_tokens=True, return_attention_mask=True, padding='max_length', return_tensors='pt')
    positions_ = {}
    positions_['position'] = torch.tensor(positions, dtype=torch.float)
    positions = positions_

    class CustomDataset(torch.utils.data.Dataset):
        def __init__(self, encodings, positions):
            self.encodings = encodings
            self.positions = positions

        def __getitem__(self, idx):
            item = {key: val[idx] for key, val in self.encodings.items()}
            # position
            item['position'] = self.positions['position'][idx]
            return item

        def __len__(self):
            return len(self.encodings['input_ids'])

    dataset = CustomDataset(encodings, positions)
    return dataset
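A sketch of building the train/validation datasets directly, mirroring what `NLP_classification.training` does internally (model name and CSV path taken from configs/config.yaml; the batch size is an assumption):

from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from src.datas import make_dataset

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_cased')
train_ds, val_ds = make_dataset('datasets/mix_dataset_2022_08_30.csv', tokenizer,
                                max_length=512, sentence_piece=False)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
batch = next(iter(train_loader))
print(batch['input_ids'].shape, batch['label'].shape, batch['position'].shape)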
src/model.py
ADDED
@@ -0,0 +1,225 @@
import torch
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from torch.nn import Identity
from transformers.activations import get_activation
import numpy as np
#from torch_scatter import scatter_add  # needed only for the bertsum path
from .utils import input_check, pos_encoding

class classification_model(torch.nn.Module):
    def __init__(self, pretrained_model, config, num_classifier=1, num_pos_emb_layer=1, bertsum=False, device=None):
        super(classification_model, self).__init__()

        self.config = config
        self.num_labels = config.num_labels
        self.pretrained_model = pretrained_model
        if hasattr(config, 'd_model'):
            self.pretrained_hidden = config.d_model
        elif hasattr(config, 'hidden_size'):
            self.pretrained_hidden = config.hidden_size
        self.sequence_summary = SequenceSummary(config)
        self.bertsum = bertsum
        self.device = device
        self.return_hidden = False
        self.return_hidden_pretrained = False

        if self.bertsum:
            #self.pooling_1 = GATpooling(self.pretrained_hidden)
            #self.fnn_1 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
            self.pooling_2 = GATpooling(self.pretrained_hidden, self.device)
            self.fnn_2 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)

        self.pos_emb_layer = nn.Sequential(*[nn.Linear(self.pretrained_hidden, self.pretrained_hidden) for _ in range(num_pos_emb_layer)])

        dim_list = np.linspace(self.pretrained_hidden, config.num_labels, num_classifier+1, dtype=np.int32)
        #dim_list = np.linspace(768, config.num_labels, num_classifier+1, dtype=np.int32)
        self.classifiers = nn.ModuleList()
        for c in range(num_classifier):
            self.classifiers.append(nn.Linear(dim_list[c], dim_list[c+1]))

    def forward(self, inputs):
        hidden_states = None
        input_ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids']
        attention_mask = inputs['attention_mask']
        position = inputs['position']
        transformer_inputs = input_check({'input_ids':input_ids, 'token_type_ids':token_type_ids, 'attention_mask':attention_mask}, self.pretrained_model)

        pretrained_output = self.pretrained_model(**transformer_inputs)
        output = pretrained_output[0]

        if self.return_hidden_pretrained and self.return_hidden:
            hidden_states = pretrained_output[1]
        if self.bertsum:
            output = scatter_add(output, inputs['sentence_batch'], dim=-2)
            #output = self.pooling_1(output, inputs['sentence_batch'])
            #output = self.fnn_1(output)
            output = self.pooling_2(output)
            output = output.squeeze()
            output = self.fnn_2(output)
        else:
            output = self.sequence_summary(output)

        # paragraph positional encoding vector add
        pos_emb = pos_encoding(position, self.pretrained_hidden).to(self.device, dtype=torch.float)
        output = torch.add(output, pos_emb)
        output = self.pos_emb_layer(output)

        if self.return_hidden and not self.return_hidden_pretrained:
            hidden_states = output
        for layer in self.classifiers:
            output = layer(output)

        logits = output

        if 'labels' in inputs.keys():
            loss = self.classification_loss_f(inputs, logits)
        else:
            loss = None

        return loss, output, hidden_states

    def classification_loss_f(self, inputs, logits):
        labels = inputs['labels']
        loss = None

        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return loss


class GATpooling(nn.Module):
    def __init__(self, hidden_size, device=None):
        super(GATpooling, self).__init__()
        self.gate_nn = nn.Linear(hidden_size, 1)
        self.device = device

    def forward(self, x, batch=None):
        if batch is None:
            batch = torch.zeros(x.shape[-2], dtype=torch.long).to(self.device)
        gate = self.gate_nn(x)
        gate = F.softmax(gate, dim=-1)
        out = scatter_add(gate*x, batch, dim=-2)
        return out


class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.
    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):
            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention
            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "mean")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.
        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
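Since `classification_model.forward` takes a single dict rather than keyword arguments, a minimal instantiation sketch may clarify the input contract (CPU, randomly initialized weights; the short sequence length and random ids are illustrative assumptions):

import torch
from transformers import AutoConfig, AutoModel
from src.model import classification_model

config = AutoConfig.from_pretrained('allenai/scibert_scivocab_cased', num_labels=6)
encoder = AutoModel.from_config(config)  # randomly initialized encoder
model = classification_model(encoder, config, device=torch.device('cpu'))

inputs = {
    'input_ids': torch.randint(0, config.vocab_size, (2, 16)),
    'token_type_ids': torch.zeros(2, 16, dtype=torch.long),
    'attention_mask': torch.ones(2, 16, dtype=torch.long),
    'position': torch.tensor([0.0, 4.0]),  # paragraph index, fed to pos_encoding
}
loss, logits, hidden = model(inputs)  # loss is None because no 'labels' key
print(logits.shape)                   # torch.Size([2, 6])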
src/run.py
ADDED
@@ -0,0 +1,279 @@
from glob import glob
from tqdm import tqdm
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
import torch
import os
import ast
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import EarlyStoppingCallback
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.utils import shuffle
from transformers import get_cosine_schedule_with_warmup
from torch.nn import functional as F
import random
import pandas as pd
from .datas import make_dataset, make_extract_dataset
from .utils import set_seed, accuracy_per_class, compute_metrics, model_eval, checkpoint_save, EarlyStopping, model_freeze, get_hidden
from .model import classification_model
from transformers import BigBirdTokenizer
import transformers

class NLP_classification():
    def __init__(self, model_name=None, data_file=None, max_length=None, random_state=1000, task_type='onehot', freeze_layers=None, num_classifier=1, num_pos_emb_layer=1, gpu_num=0, sentence_piece=True, bertsum=False):
        self.model_name = model_name
        self.data_file = data_file
        self.max_length = max_length
        self.random_state = random_state
        self.task_type = task_type
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
        if model_name == 'google/bigbird-roberta-base':
            self.tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
        self.config = AutoConfig.from_pretrained(model_name, num_labels=6)

        #self.pretrained_model = AutoModelForSequenceClassification.from_config(self.config)
        self.pretrained_model = AutoModel.from_config(self.config)
        self.freeze_layers = freeze_layers
        self.num_classifier = num_classifier
        self.num_pos_emb_layer = num_pos_emb_layer
        self.gpu_num = gpu_num
        self.sentence_piece = sentence_piece
        self.bertsum = bertsum
        if self.max_length is None:
            self.padding = 'longest'
        else:
            self.padding = 'max_length'


    def training(self, epochs=50, batch_size=4, lr=1e-5, dropout=0.1, data_cut=None, early_stop_count=10,
                 wandb_log=False, wandb_project=None, wandb_group=None, wandb_name=None, wandb_memo=None):
        os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(int(self.gpu_num))
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device('cuda:{0}'.format(int(self.gpu_num)))
        torch.cuda.set_device(device)
        set_seed(self.random_state)
        torch.set_num_threads(10)

        if wandb_log is True:
            import wandb
            wandb.init(project=wandb_project, reinit=True, group=wandb_group, notes=wandb_memo)
            wandb.run.name = wandb_name
            wandb.run.save()
            parameters = wandb.config
            parameters.lr = lr
            parameters.batch_size = batch_size
            parameters.dropout = dropout
            parameters.train_num = data_cut
            parameters.max_length = self.max_length
            parameters.model_name = self.model_name
            parameters.task_type = self.task_type

        '''data loading'''
        train_dataset, val_dataset = make_dataset(csv_file=self.data_file, tokenizer=self.tokenizer, max_length=self.max_length, padding=self.padding, random_state=self.random_state, data_cut=data_cut, sentence_piece=self.sentence_piece)

        '''loader making'''
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=SequentialSampler(val_dataset))

        ''' model load '''
        model = classification_model(self.pretrained_model, self.config, num_classifier=self.num_classifier, num_pos_emb_layer=self.num_pos_emb_layer, bertsum=self.bertsum, device=device)
        model = model_freeze(model, self.freeze_layers)
        model.to(device)

        ''' running setting '''
        loss_fn = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, eps=1e-8)
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=(len(train_loader)*epochs))
        early_stopping = EarlyStopping(patience=early_stop_count, verbose=True)

        ''' running '''
        best_epoch = None
        best_val_f1 = None

        for epoch in range(epochs):
            model.train()
            loss_all = 0
            step = 0

            for data in tqdm(train_loader):
                input_ids = data['input_ids'].to(device, dtype=torch.long)
                mask = data['attention_mask'].to(device, dtype=torch.long)
                token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
                if self.task_type == 'onehot':
                    targets = data['label_onehot'].to(device, dtype=torch.float)
                elif self.task_type == 'scalar':
                    targets = data['label'].to(device, dtype=torch.long)
                position = data['position']
                inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                          'labels': targets, 'position': position}
                if self.sentence_piece:
                    sentence_batch = data['sentence_batch'].to(device, dtype=torch.long)
                    inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                              'labels': targets, 'sentence_batch': sentence_batch, 'position': position}

                outputs = model(inputs)
                output = outputs[1]
                loss = outputs[0]

                optimizer.zero_grad()
                #loss = loss_fn(output, targets)
                loss_all += loss.item()

                loss.backward()
                optimizer.step()
                scheduler.step()
                #print(optimizer.param_groups[0]['lr'])

            train_loss = loss_all/len(train_loader)
            val_loss, val_acc, val_precision, val_recall, val_f1 = model_eval(model, device, val_loader, task_type=self.task_type, sentence_piece=self.sentence_piece)

            if wandb_log is True:
                wandb.log({'train_loss':train_loss, 'val_loss':val_loss, 'val_acc':val_acc,
                           'val_precision':val_precision, 'val_recall':val_recall, 'val_f1':val_f1})

            if best_val_f1 is None or val_f1 >= best_val_f1:
                best_epoch = epoch+1
                best_val_f1 = val_f1
                checkpoint_save(model, val_f1, wandb_name=wandb_name)

            print('Epoch: {:03d}, Train Loss: {:.7f}, Val Loss: {:.7f}, Val Acc: {:.7f}, Val Precision: {:.7f}, Val Recall: {:.7f}, Val F1: {:.7f} '.format(epoch+1, train_loss, val_loss, val_acc, val_precision, val_recall, val_f1))

            early_stopping(val_f1)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        if wandb_log is True:
            wandb.finish()


    def prediction(self, selected_model=None, batch_size=8):
        os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(int(self.gpu_num))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        set_seed(self.random_state)
        torch.set_num_threads(10)
        task_type = self.task_type

        '''data loading'''
        train_dataset, val_dataset = make_dataset(csv_file=self.data_file, tokenizer=self.tokenizer, max_length=self.max_length, padding=self.padding, random_state=self.random_state, data_cut=None, sentence_piece=self.sentence_piece)

        '''loader making'''
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=SequentialSampler(val_dataset))

        ''' model load '''
        model = classification_model(self.pretrained_model, self.config, num_classifier=self.num_classifier, num_pos_emb_layer=self.num_pos_emb_layer, bertsum=self.bertsum, device=device)
        model.load_state_dict(torch.load(selected_model, map_location=device))
        model.to(device)

        ''' prediction '''
        print('start trainset prediction')
        train_results = model_eval(model, device, train_loader, task_type=self.task_type, return_values=True, sentence_piece=self.sentence_piece)
        print('start evalset prediction')
        eval_results = model_eval(model, device, val_loader, task_type=self.task_type, return_values=True, sentence_piece=self.sentence_piece)

        print('train result: acc:{0} | precision:{1} | recall:{2} | f1:{3}'.format(train_results[1], train_results[2], train_results[3], train_results[4]))
        print('eval result: acc:{0} | precision:{1} | recall:{2} | f1:{3}'.format(eval_results[1], eval_results[2], eval_results[3], eval_results[4]))

        total_text = train_results[7] + eval_results[7]
        total_out = train_results[6] + eval_results[6]
        total_target = train_results[5] + eval_results[5]

        if self.task_type == 'onehot':
            total_out = [i.argmax() for i in total_out]
            total_target = [i.argmax() for i in total_target]

        total_data = {'text':total_text, 'label':total_target, 'predict':total_out}
        total_df = pd.DataFrame(total_data)

        ''' result return '''
        return total_df


    def get_embedding(self, selected_model=None, batch_size=8, return_hidden=True, return_hidden_pretrained=False):
        os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(int(self.gpu_num))
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device('cuda:{0}'.format(int(self.gpu_num)))
        torch.cuda.set_device(device)
        set_seed(self.random_state)
        torch.set_num_threads(10)
        task_type = self.task_type

        '''data loading'''
        train_dataset, val_dataset = make_dataset(csv_file=self.data_file, tokenizer=self.tokenizer, max_length=self.max_length, padding=self.padding, random_state=self.random_state, data_cut=None, sentence_piece=self.sentence_piece)

        '''loader making'''
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=SequentialSampler(val_dataset))

        ''' model load '''
        model = classification_model(self.pretrained_model, self.config, num_classifier=self.num_classifier, num_pos_emb_layer=self.num_pos_emb_layer, bertsum=self.bertsum, device=device)
        model.return_hidden = return_hidden
        model.return_hidden_pretrained = return_hidden_pretrained
        if selected_model is not None:
            model.load_state_dict(torch.load(selected_model, map_location=device))
        model.to(device)

        ''' get hidden '''
        print('start make hidden states (train set)')
        train_hiddens, train_targets = get_hidden(model, device, train_loader, task_type=self.task_type, sentence_piece=self.sentence_piece)
        print('start make hidden states (eval set)')
        eval_hiddens, eval_targets = get_hidden(model, device, val_loader, task_type=self.task_type, sentence_piece=self.sentence_piece)
        total_hiddens = np.array(train_hiddens + eval_hiddens)
        total_targets = np.array(train_targets + eval_targets)

        return total_hiddens, total_targets


    def label_extraction(self, paragraphs, positions, selected_model=None, batch_size=16):
        label_dict = {'Abstract':0, 'Introduction':1, 'Main':2, 'Methods':3, 'Summary':4, 'Captions':5}
        #os.environ["CUDA_VISIBLE_DEVICES"] = "{0}".format(int(self.gpu_num))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        set_seed(self.random_state)
        torch.set_num_threads(10)

        ''' data to list '''
        is_list = True
        if not isinstance(paragraphs, list):
            paragraphs = [paragraphs]
            is_list = False
        if not isinstance(positions, list):
            positions = [positions]
            is_list = False

        '''data encoding'''
        dataset = make_extract_dataset(paragraphs, positions, tokenizer=self.tokenizer, max_length=self.max_length)

        '''loader making'''
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        ''' model load '''
        model = classification_model(self.pretrained_model, self.config, num_classifier=self.num_classifier, num_pos_emb_layer=self.num_pos_emb_layer, bertsum=self.bertsum, device=device)
        model.load_state_dict(torch.load(selected_model, map_location=device))
        model.to(device)

        ''' prediction '''
        model.eval()
        predicts = []
        with torch.no_grad():
            for batch in tqdm(data_loader):
                inputs = {}
                inputs['input_ids'] = batch['input_ids'].to(device)
                inputs['attention_mask'] = batch['attention_mask'].to(device)
                inputs['token_type_ids'] = batch['token_type_ids'].to(device)
                inputs['position'] = batch['position']
                outputs = model(inputs)
                logits = outputs[1]
                logits = logits.detach().cpu().numpy()
                logits = logits.argmax(axis=1).flatten()
                logits = logits.tolist()
                predicts.extend(logits)
        predicts = [list(label_dict.keys())[list(label_dict.values()).index(i)] for i in predicts]

        if not is_list:
            predicts = predicts[0]
        return predicts
src/utils.py
ADDED
@@ -0,0 +1,299 @@
from tqdm import tqdm
import numpy as np
import torch
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.utils import shuffle
import random
import datetime as dt
import os
from glob import glob
from spacy.lang.en import English
import inspect

def checkpoint_save(model, val_loss, checkpoint_dir=None, wandb_name=None):
    if checkpoint_dir is None:
        checkpoint_dir = './save_model'
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    x = dt.datetime.now()
    y = x.year
    m = x.month
    d = x.day

    if wandb_name is None:
        wandb_name = "testing"

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "{}_{}_{}_{:.4f}_{}.pt".format(y, m, d, val_loss, wandb_name)))

    #saved_dict_list = glob(os.path.join(checkpoint_dir, '*.pt'))
    saved_dict_list = glob(os.path.join(checkpoint_dir, '{}_{}_{}_*_{}.pt'.format(y, m, d, wandb_name)))

    # keep only the checkpoint with the highest score for this run; delete the rest
    val_loss_list = np.array([float(os.path.basename(loss).split("_")[3]) for loss in saved_dict_list])
    saved_dict_list.pop(val_loss_list.argmax())

    for i in saved_dict_list:
        os.remove(i)


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

def accuracy_per_class(preds, labels):
    label_dict = {'Abstract':0, 'Intro':1, 'Main':2, 'Method':3, 'Summary':4, 'Caption':5}
    label_dict_inverse = {v: k for k, v in label_dict.items()}

    class_list = []
    acc_list = []
    for label in list(label_dict.values()):
        y_preds = preds[labels==label]
        y_true = labels[labels==label]
        class_list.append(label_dict_inverse[label])
        acc_list.append("{0}/{1}".format(len(y_preds[y_preds==label]), len(y_true)))

    print("{:10} {:10} {:10} {:10} {:10} {:10}".format(class_list[0], class_list[1], class_list[2], class_list[3], class_list[4], class_list[5]))
    print("{:10} {:10} {:10} {:10} {:10} {:10}".format(acc_list[0], acc_list[1], acc_list[2], acc_list[3], acc_list[4], acc_list[5]))


def compute_metrics(output, target, task_type='onehot'):
    if task_type == 'onehot':
        pred = np.argmax(output, axis=1).flatten()
        labels = np.argmax(target, axis=1).flatten()
    elif task_type == 'scalar':
        pred = np.argmax(output, axis=1).flatten()
        labels = np.array(target).flatten()
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='macro')
    precision = precision_score(y_true=labels, y_pred=pred, average='macro', zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, average='macro')

    accuracy_per_class(pred, labels)

    return [accuracy, precision, recall, f1]

def input_check(input_dict, model):
    model_inputs = inspect.signature(model.forward).parameters.keys()
    inputs = {}
    for key, val in input_dict.items():
        if key in model_inputs:
            inputs[key] = val
    return inputs


def model_eval(model, device, loader, task_type='onehot', return_values=False, sentence_piece=False):
    model.eval()
    error = 0
    accuracy = 0
    precision = 0
    recall = 0
    f1 = 0
    eval_targets = []
    eval_outputs = []
    eval_texts = []
    with torch.no_grad():
        for data in tqdm(loader):
            eval_texts.extend(data['text'])
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            if task_type == 'onehot':
                targets = data['label_onehot'].to(device, dtype=torch.float)
            elif task_type == 'scalar':
                targets = data['label'].to(device, dtype=torch.long)
            position = data['position']
            inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                      'labels': targets, 'position': position}
            if sentence_piece:
                sentence_batch = data['sentence_batch'].to(device, dtype=torch.long)
                inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                          'labels': targets, 'sentence_batch': sentence_batch, 'position': position}
            outputs = model(inputs)
            output = outputs[1]
            loss = outputs[0]
            #loss = loss_fn(output, targets)
            error += loss
            #output = torch.sigmoid(output)
            eval_targets.extend(targets.detach().cpu().numpy())
            eval_outputs.extend(output.detach().cpu().numpy())

    error = error / len(loader)
    accuracy, precision, recall, f1 = compute_metrics(eval_outputs, eval_targets, task_type=task_type)

    if return_values:
        return [error, accuracy, precision, recall, f1, eval_targets, eval_outputs, eval_texts]
    else:
        return [error, accuracy, precision, recall, f1]


def get_hidden(model, device, loader, task_type='onehot', sentence_piece=False):
    model.eval()
    total_hidden_state = []
    total_targets = []
    with torch.no_grad():
        for data in tqdm(loader):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            if task_type == 'onehot':
                targets = data['label_onehot'].to(device, dtype=torch.float)
            elif task_type == 'scalar':
                targets = data['label'].to(device, dtype=torch.long)
            position = data['position']
            inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                      'labels': targets, 'position': position}
            if sentence_piece:
                sentence_batch = data['sentence_batch'].to(device, dtype=torch.long)
                inputs = {'input_ids': input_ids, 'attention_mask': mask, 'token_type_ids': token_type_ids,
                          'labels': targets, 'sentence_batch': sentence_batch, 'position': position}
            outputs = model(inputs)
            hidden_state = outputs[2]
            total_hidden_state.extend(hidden_state.detach().cpu().numpy())
            total_targets.extend(targets.detach().cpu().numpy())
    return total_hidden_state, total_targets


def sentencepiece(paragraph_list, spacy_nlp, tokenizer, max_length=512):
    # Note: the token_type_ids here are not the ones produced by the tokenizer; they are set
    # manually, alternating 0/1 per sentence. Models such as XLNet (where CLS gets 2) break
    # this rule, so revisit this if it causes problems later.
    encode_datas = {'input_ids': [], 'token_type_ids': [], 'attention_mask': [], 'sentence_batch': []}
    for paragraph in paragraph_list:
        doc = spacy_nlp(paragraph)
        sentence_encode = [sent.text for sent in doc.sents]
        sentence_encode = tokenizer.batch_encode_plus(sentence_encode, max_length=max_length, padding='max_length', return_attention_mask=True, return_token_type_ids=True)

        sentence_list = sentence_encode['input_ids']
        mask_list = sentence_encode['attention_mask']
        pad_token = None
        pad_position = None
        total_sentence = torch.tensor([], dtype=torch.int)
        token_type_ids = []
        s_batch = []

        for n, s in enumerate(sentence_list):
            if pad_token is None:
                pad_token = s[mask_list[n].index(0)]
            if pad_position is None:
                if s[0] == pad_token:
                    pad_position = 'start'
                else:
                    pad_position = 'end'

            s = torch.tensor(s, dtype=torch.int)
            s = s[s != pad_token]
            total_length = len(total_sentence) + len(s)
            if total_length > max_length:
                break
            total_sentence = torch.concat([total_sentence, s])
            token_type_ids = token_type_ids + [n%2]*len(s)
            s_batch = s_batch + [n]*len(s)

        total_sentence = total_sentence.tolist()
        pad_length = max_length - len(total_sentence)
        attention_mask = [1]*len(total_sentence)
        if pad_position == 'end':
            total_sentence = total_sentence + [pad_token]*pad_length
            attention_mask = attention_mask + [0]*pad_length
            s_batch = s_batch + [max(s_batch)+1]*pad_length
            if n%2 == 0:
                token_type_ids = token_type_ids + [1]*pad_length
            else:
                token_type_ids = token_type_ids + [0]*pad_length

        elif pad_position == 'start':
            total_sentence = [pad_token]*pad_length + total_sentence
            attention_mask = [0]*pad_length + attention_mask
            s_batch = [max(s_batch)+1]*pad_length + s_batch
            if n%2 == 0:
                token_type_ids = [0]*pad_length + token_type_ids
            else:
                token_type_ids = [1]*pad_length + token_type_ids

        encode_datas['input_ids'].append(total_sentence)
        encode_datas['token_type_ids'].append(token_type_ids)
        encode_datas['attention_mask'].append(attention_mask)
        encode_datas['sentence_batch'].append(s_batch)

    return encode_datas


class EarlyStopping:
    """Stops training early if the validation score has not improved after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after the last improvement of the validation score.
                            Default: 7
            verbose (bool): If True, prints a message for each validation score improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.f1_score_max = 0.
        self.delta = delta

    def __call__(self, f1_score):

        score = -f1_score

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(f1_score)
        elif score > self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(f1_score)
            self.counter = 0

    def save_checkpoint(self, f1_score):
        '''Prints the improvement when the monitored F1 score increases.'''
        if self.verbose:
            print(f'F1 score increased ({self.f1_score_max:.6f} --> {f1_score:.6f}).')
        self.f1_score_max = f1_score


def model_freeze(model, freeze_layers=None):
    if freeze_layers == 0:
        return model

    if freeze_layers is not None:
        for param in model.pretrained_model.base_model.word_embedding.parameters():
            param.requires_grad = False

        if freeze_layers != -1:
            # if freeze_layer_count == -1, we only freeze the embedding layer
            # otherwise we freeze the first `freeze_layer_count` encoder layers
            for layer in model.pretrained_model.base_model.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False
    return model

def pos_encoding(pos, d, n=10000):
    encoding_list = []
    for p in pos:
        P = np.zeros(d)
        for i in np.arange(int(d/2)):
            denominator = np.power(n, 2*i/d)
            P[2*i] = np.sin(p/denominator)
            P[2*i+1] = np.cos(p/denominator)
        encoding_list.append(P)
    return torch.tensor(np.array(encoding_list))
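`pos_encoding` is the classic sinusoidal encoding applied at paragraph granularity: P[2i] = sin(p / n^(2i/d)) and P[2i+1] = cos(p / n^(2i/d)) for paragraph index p. A quick sketch of what gets added to each paragraph vector:

from src.utils import pos_encoding

emb = pos_encoding([0, 4, 9], d=768)  # one 768-dim vector per paragraph position
print(emb.shape)                      # torch.Size([3, 768])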
train.py
ADDED
@@ -0,0 +1,12 @@
from src.run import NLP_classification
import wandb
import yaml

config_file = 'configs/config.yaml'
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

trainer = NLP_classification(**config['model'])

trainer.training(**config['train'], **config['wandb'])

wandb.finish()
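Everything in this commit is wired through configs/config.yaml: running `python train.py` fine-tunes the SciBERT classifier with W&B logging and keeps the best-F1 checkpoint under ./save_model, while `python app.py` loads configs/best_model.pt via extraction.py and serves the Gradio demo.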