Text Classification
PyTorch
English
eurovoc
scampion commited on
Commit
b552d82
Β·
verified Β·
0 Parent(s):

initial commit

Browse files
Files changed (11) hide show
  1. .gitattributes +35 -0
  2. .gitignore +2 -0
  3. README.md +100 -0
  4. eurovoc.py +212 -0
  5. handler.py +74 -0
  6. img/architecture.png +0 -0
  7. mlb.pickle +3 -0
  8. pytorch_model.bin +3 -0
  9. requirements.txt +6 -0
  10. test_handler.py +18 -0
  11. train.ipynb +935 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea
2
+ __pycache__
README.md ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: eupl-1.1
3
+ datasets:
4
+ - EuropeanParliament/cellar_eurovoc
5
+ language:
6
+ - en
7
+ metrics:
8
+ - type: f1
9
+ value: 0.72
10
+ name: micro F1
11
+ args:
12
+ threshold: 0.34
13
+ - type: NDCG@3
14
+ value: 0.84
15
+ name: NDCG@5
16
+ - type: NDCG@5
17
+ value: 0.80
18
+ name: NDCG@5
19
+ - type: NDCG@10
20
+ value: 0.83
21
+ name: NDCG@10
22
+ tags:
23
+ - eurovoc
24
+ pipeline_tag: text-classification
25
+
26
+ widget:
27
+ - text: "The Union condemns the continuing grave human rights violations by the Myanmar armed forces, including torture, sexual and gender-based violence, the persecution of civil society actors, human rights defenders and journalists, and attacks on the civilian population, including ethnic and religious minorities."
28
+
29
+ ---
30
+
31
+ # Eurovoc Multilabel Classifer
32
+
33
+ [EuroVoc](https://op.europa.eu/fr/web/eu-vocabularies) is a large multidisciplinary multilingual hierarchical thesaurus of more than 7000 classes covering the activities of EU institutions.
34
+ Given the number of legal documents produced every day and the huge mass of pre-existing documents to be classified high quality automated or semi-automated classification methods are most welcome in this domain.
35
+
36
+ This model based on BERT Deep Neural Network was trained on more than 200,000 documents to achieve that task and is used in a production environment via the huggingface inference endpoint.
37
+
38
+
39
+ ## Architecture
40
+
41
+ ![architecture](img/architecture.png)
42
+
43
+ 7331 Eurovoc labels
44
+
45
+ ## Usage
46
+
47
+ ```python
48
+ from eurovoc import EurovocTagger
49
+ model = EurovocTagger.from_pretrained("EuropeanParliament/eurovoc_en")
50
+ ```
51
+
52
+ ## Metrics
53
+
54
+
55
+ ### Eurlex57k Dataset
56
+
57
+ | Metric | Value | Threshold Value |
58
+ |------------|----------|-----------------|
59
+ | Micro F1 | 0.7233 | 0.34 |
60
+ | NDCG@3 | 0.8438 | - |
61
+ | NDCG@5 | 0.8079 | - |
62
+ | NDCG@10 | 0.833 | - |
63
+
64
+ These values are in line with the state of the art in the field, see the publication [Large Scale Legal Text Classification Using Transformer Models](https://arxiv.org/pdf/2010.12871.pdf).
65
+
66
+
67
+ ## Inference Endpoint
68
+
69
+ Member of the [European Parliament HuggingFace Organisation](https://huggingface.co/EuropeanParliament) can access to our inference endpoint.
70
+
71
+ ### Payload example
72
+
73
+ ```json
74
+ {
75
+ "inputs": "The Union condemns the continuing grave human rights violations by the Myanmar armed forces, including torture, sexual and gender-based violence, the persecution of civil society actors, human rights defenders and journalists, and attacks on the civilian population, including ethnic and religious minorities. ",
76
+ "topk": 10,
77
+ "threshold": 0.16
78
+ }
79
+
80
+ ```
81
+
82
+ result:
83
+
84
+ ```json
85
+ {'results': [{'label': 'international sanctions', 'score': 0.9994925260543823},
86
+ {'label': 'economic sanctions', 'score': 0.9991770386695862},
87
+ {'label': 'natural person', 'score': 0.9591936469078064},
88
+ {'label': 'EU restrictive measure', 'score': 0.8388392329216003},
89
+ {'label': 'legal person', 'score': 0.45630475878715515},
90
+ {'label': 'Burma/Myanmar', 'score': 0.43375277519226074}]}
91
+ ```
92
+
93
+ Only six results, because the following one score is less that 0.16
94
+
95
+ Default value, topk = 5 and threshold = 0.16
96
+
97
+
98
+ ## Author(s)
99
+
100
+ SΓ©bastien Campion <[email protected]>
eurovoc.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import torch.nn as nn
6
+ from transformers import BertTokenizerFast as BertTokenizer, AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+
10
+ class EurovocDataset(Dataset):
11
+
12
+ def __init__(
13
+ self,
14
+ text: np.array,
15
+ labels: np.array,
16
+ tokenizer: BertTokenizer,
17
+ max_token_len: int = 128
18
+ ):
19
+ self.tokenizer = tokenizer
20
+ self.text = text
21
+ self.labels = labels
22
+ self.max_token_len = max_token_len
23
+
24
+ def __len__(self):
25
+ return len(self.labels)
26
+
27
+ def __getitem__(self, index: int):
28
+ text = self.text[index][0]
29
+ labels = self.labels[index]
30
+
31
+ encoding = self.tokenizer.encode_plus(
32
+ text,
33
+ add_special_tokens=True,
34
+ max_length=self.max_token_len,
35
+ return_token_type_ids=False,
36
+ padding="max_length",
37
+ truncation=True,
38
+ return_attention_mask=True,
39
+ return_tensors='pt',
40
+ )
41
+
42
+ return dict(
43
+ text=text,
44
+ input_ids=encoding["input_ids"].flatten(),
45
+ attention_mask=encoding["attention_mask"].flatten(),
46
+ labels=torch.FloatTensor(labels)
47
+ )
48
+
49
+
50
+ class EuroVocLongTextDataset(Dataset):
51
+
52
+ def __splitter__(text, max_lenght):
53
+ l = text.split()
54
+ for i in range(0, len(l), max_lenght):
55
+ yield l[i:i + max_lenght]
56
+
57
+ def __init__(
58
+ self,
59
+ text: np.array,
60
+ labels: np.array,
61
+ tokenizer: BertTokenizer,
62
+ max_token_len: int = 128
63
+ ):
64
+ self.tokenizer = tokenizer
65
+ self.text = text
66
+ self.labels = labels
67
+ self.max_token_len = max_token_len
68
+
69
+ self.chunks_and_labels = [(c, l) for t, l in zip(self.text, self.labels) for c in self.__splitter__(t)]
70
+
71
+ self.encoding = self.tokenizer.batch_encode_plus(
72
+ [c for c, _ in self.chunks_and_labels],
73
+ add_special_tokens=True,
74
+ max_length=self.max_token_len,
75
+ return_token_type_ids=False,
76
+ padding="max_length",
77
+ truncation=True,
78
+ return_attention_mask=True,
79
+ return_tensors='pt',
80
+ )
81
+
82
+ def __len__(self):
83
+ return len(self.chunks_and_labels)
84
+
85
+ def __getitem__(self, index: int):
86
+ text, labels = self.chunks_and_labels[index]
87
+
88
+ return dict(
89
+ text=text,
90
+ input_ids=self.encoding[index]["input_ids"].flatten(),
91
+ attention_mask=self.encoding[index]["attention_mask"].flatten(),
92
+ labels=torch.FloatTensor(labels)
93
+ )
94
+
95
+
96
+ class EurovocDataModule(pl.LightningDataModule):
97
+
98
+ def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
99
+ super().__init__()
100
+
101
+ self.batch_size = batch_size
102
+ self.x_tr = x_tr
103
+ self.y_tr = y_tr
104
+ self.x_test = x_test
105
+ self.y_test = y_test
106
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
107
+ self.max_token_len = max_token_len
108
+
109
+ def setup(self, stage=None):
110
+ self.train_dataset = EurovocDataset(
111
+ self.x_tr,
112
+ self.y_tr,
113
+ self.tokenizer,
114
+ self.max_token_len
115
+ )
116
+
117
+ self.test_dataset = EurovocDataset(
118
+ self.x_test,
119
+ self.y_test,
120
+ self.tokenizer,
121
+ self.max_token_len
122
+ )
123
+
124
+ def train_dataloader(self):
125
+ return DataLoader(
126
+ self.train_dataset,
127
+ batch_size=self.batch_size,
128
+ shuffle=True,
129
+ num_workers=2
130
+ )
131
+
132
+ def val_dataloader(self):
133
+ return DataLoader(
134
+ self.test_dataset,
135
+ batch_size=self.batch_size,
136
+ num_workers=2
137
+ )
138
+
139
+ def test_dataloader(self):
140
+ return DataLoader(
141
+ self.test_dataset,
142
+ batch_size=self.batch_size,
143
+ num_workers=2
144
+ )
145
+
146
+
147
+ class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):
148
+
149
+ def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
150
+ super().__init__()
151
+ self.bert = AutoModel.from_pretrained(bert_model_name)
152
+ self.dropout = nn.Dropout(p=0.2)
153
+ self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
154
+ self.criterion = nn.BCELoss()
155
+ self.lr = lr
156
+ self.eps = eps
157
+
158
+ def forward(self, input_ids, attention_mask, labels=None):
159
+ output = self.bert(input_ids, attention_mask=attention_mask)
160
+ output = self.dropout(output.pooler_output)
161
+ output = self.classifier1(output)
162
+ output = torch.sigmoid(output)
163
+ loss = 0
164
+ if labels is not None:
165
+ loss = self.criterion(output, labels)
166
+ return loss, output
167
+
168
+ def training_step(self, batch, batch_idx):
169
+ input_ids = batch["input_ids"]
170
+ attention_mask = batch["attention_mask"]
171
+ labels = batch["labels"]
172
+ loss, outputs = self(input_ids, attention_mask, labels)
173
+ self.log("train_loss", loss, prog_bar=True, logger=True)
174
+ return {"loss": loss, "predictions": outputs, "labels": labels}
175
+
176
+ def validation_step(self, batch, batch_idx):
177
+ input_ids = batch["input_ids"]
178
+ attention_mask = batch["attention_mask"]
179
+ labels = batch["labels"]
180
+ loss, outputs = self(input_ids, attention_mask, labels)
181
+ self.log("val_loss", loss, prog_bar=True, logger=True)
182
+ return loss
183
+
184
+ def test_step(self, batch, batch_idx):
185
+ input_ids = batch["input_ids"]
186
+ attention_mask = batch["attention_mask"]
187
+ labels = batch["labels"]
188
+ loss, outputs = self(input_ids, attention_mask, labels)
189
+ self.log("test_loss", loss, prog_bar=True, logger=True)
190
+ return loss
191
+
192
+ def on_train_epoch_end(self, *args, **kwargs):
193
+ return
194
+ #labels = []
195
+ #predictions = []
196
+ #for output in args['outputs']:
197
+ # for out_labels in output["labels"].detach().cpu():
198
+ # labels.append(out_labels)
199
+ # for out_predictions in output["predictions"].detach().cpu():
200
+ # predictions.append(out_predictions)
201
+
202
+ #labels = torch.stack(labels).int()
203
+ #predictions = torch.stack(predictions)
204
+
205
+ #for i, name in enumerate(mlb.classes_):
206
+ # class_roc_auc = auroc(predictions[:, i], labels[:, i])
207
+ # self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
208
+
209
+
210
+ def configure_optimizers(self):
211
+ return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
212
+
handler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import numpy as np
3
+ import pickle
4
+
5
+ from sklearn.preprocessing import MultiLabelBinarizer
6
+ from transformers import AutoTokenizer
7
+ import torch
8
+
9
+ from eurovoc import EurovocTagger
10
+
11
+ BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
12
+ MAX_LEN = 512
13
+ TEXT_MAX_LEN = MAX_LEN * 50
14
+ tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
15
+
16
+
17
+ class EndpointHandler:
18
+ mlb = MultiLabelBinarizer()
19
+
20
+ def __init__(self, path=""):
21
+ self.mlb = pickle.load(open(f"{path}/mlb.pickle", "rb"))
22
+
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ self.model = EurovocTagger.from_pretrained(path,
25
+ bert_model_name=BERT_MODEL_NAME,
26
+ n_classes=len(self.mlb.classes_),
27
+ map_location=self.device)
28
+ self.model.eval()
29
+ self.model.freeze()
30
+
31
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
32
+ """
33
+ data args:
34
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
35
+ kwargs
36
+ Return:
37
+ A :obj:`list` | `dict`: will be serialized and returned
38
+ """
39
+
40
+ text = data.pop("inputs", data)
41
+ topk = data.pop("topk", 5)
42
+ threshold = data.pop("threshold", 0.16)
43
+ debug = data.pop("debug", False)
44
+ prediction = self.get_prediction(text)
45
+ results = [{"label": label, "score": float(score)} for label, score in
46
+ zip(self.mlb.classes_, prediction[0].tolist())]
47
+ results = sorted(results, key=lambda x: x["score"], reverse=True)
48
+ results = [r for r in results if r["score"] > threshold]
49
+ results = results[:topk]
50
+ if debug:
51
+ return {"results": results, "values": prediction, "input": text}
52
+ else:
53
+ return {"results": results}
54
+
55
+ def get_prediction(self, text):
56
+ # split text into chunks of MAX_LEN and get average prediction for each chunk
57
+ chunks = [text[i:i + MAX_LEN] for i in range(0, min(len(text), TEXT_MAX_LEN), MAX_LEN)]
58
+ predictions = [self._get_prediction(chunk) for chunk in chunks]
59
+ predictions = np.array(predictions).mean(axis=0)
60
+ return predictions
61
+
62
+ def _get_prediction(self, text):
63
+ item = tokenizer.encode_plus(
64
+ text,
65
+ add_special_tokens=True,
66
+ max_length=MAX_LEN,
67
+ return_token_type_ids=False,
68
+ padding="max_length",
69
+ truncation=True,
70
+ return_attention_mask=True,
71
+ return_tensors='pt')
72
+ _, prediction = self.model(item["input_ids"], item["attention_mask"])
73
+ prediction = prediction.cpu().detach().numpy()
74
+ return prediction
img/architecture.png ADDED
mlb.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb8822c2c0cee9ceeadab0afbb155106d7f55fafa58e5a16eac3280aaf9cc980
3
+ size 128152
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57719b9fd61bbe3141cfc0d38291404337dab436cc5be4ab257e88498e636e88
3
+ size 458391285
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ aiohttp==3.8.5
2
+ ipython==8.14.0
3
+ pip-chill==1.0.3
4
+ pytorch-lightning==2.0.5
5
+ scikit-learn==1.3.0
6
+ transformers==4.32.0
test_handler.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pprint import pprint
2
+ from handler import EndpointHandler
3
+
4
+ # init handler
5
+ my_handler = EndpointHandler(path=".")
6
+
7
+ # prepare sample payload
8
+
9
+ payload = {"text": "EN Official Journal of the European Union LI 183/19 COUNCIL IMPLEMENTING REGULATION (EU) 2023/1497 of 20Β July 2023 implementing Regulation (EU) NoΒ 401/2013 concerning restrictive measures in view of the situation inΒ Myanmar/Burma THE COUNCIL OF THE EUROPEAN UNION, Having regard to the Treaty on the Functioning of the European Union, Having regard to Council Regulation (EU) NoΒ 401/2013 of 2Β May 2013 concerning restrictive measures in view of the situation inΒ Myanmar/Burma and repealing Regulation (EC) NoΒ 194/2008Β (1), and in particular ArticleΒ 4i thereof, Having regard to the proposal from the High Representative of the Union for Foreign Affairs and Security Policy, Whereas: (1) On 2Β May 2013, the Council adopted Regulation (EU) NoΒ 401/2013. (2) On 31Β January 2023, the High Representative of the Union for Foreign Affairs and Security Policy issued a declaration on behalf of the Union strongly condemning the overthrow of Myanmar’s democratically-elected government by the Myanmar armed forces in blatant violation of the will of the people as expressed in the general election of 8Β November 2020. This illegitimate act reversed the country’s democratic transition and led to disastrous humanitarian, social, security, economic and human rights consequences. (3) The Union remains deeply concerned by the continuing escalation of violence and the evolution towards a protracted conflict with regional implications. The Union condemns the continuing grave human rights violations by the Myanmar armed forces, including torture, sexual and gender-based violence, the persecution of civil society actors, human rights defenders and journalists, and attacks on the civilian population, including ethnic and religious minorities. (4) In the absence of swift progress in the situation inΒ Myanmar/Burma, the Union has expressed several times its readiness to adopt further restrictive measures against those responsible for undermining democracy and the rule of law and for the serious human rights violations taking place in that country. (5) In view of the continuing grave situation inΒ Myanmar/Burma, six persons and one entity should be added to the list of natural and legal persons, entities and bodies subject to restrictive measures in Annex IV to Regulation (EU) NoΒ 401/2013. (6) Regulation (EU) NoΒ 401/2013 should therefore be amended accordingly, HAS ADOPTED THIS REGULATION: ArticleΒ 1 Annex IV to Regulation (EU) NoΒ 401/2013 is amended as set out in the Annex to this Regulation. ArticleΒ 2 This Regulation shall enter into force on the date of its publication in the Official Journal of the European Union. This Regulation shall be binding in its entirety and directly applicable in all Member States. Done at Brussels, 20Β July 2023. For the Council The President J. BORRELL FONTELLES (1)Β Β OJΒ LΒ 121, 3. 5. 2013, p. 1. ANNEX Annex IV to Regulation (EU) NoΒ 401/2013 is amended as follows: (1) the following entries are added to the list headed β€˜A. Natural persons referred to in ArticleΒ 4a’: Β  Name Identifying information Reasons Date of listing β€˜94. Aung Kyaw Min Nationality: Myanmar/Burma; Date of birth: circa 1958; Place of birth: Myanmar/Burma; Gender: male; Function: Member of State Administration Council Aung Kyaw Min has been a member of the State Administration Council (SAC) since 1Β February 2023. He is also the former Chief-Minister of Rakhine State. SAC is led by Commander in Chief Min Aung Hlaing, who took over the legislative, executive and judicial powers of the State as of 1Β February 2021, preventing the democratically-elected government from fulfilling its mandate. As member of the SAC, Aung Kyaw Min has been directly involved in and responsible for decision-making concerning state functions and is therefore responsible for undermining democracy and the rule of law in Myanmar/Burma. Additionally, the SAC has adopted decisions restricting the rights of freedom of expression, including access to information, and peaceful assembly. The military forces and authorities operating under the control of the SAC have committed serious human rights violations since 1Β February 2021, killing civilian and unarmed protestors, and have restricted freedom of assembly and of expression. As a member of the SAC, Aung Kyaw Min is directly responsible for those repressive decisions and for serious human rights violations. 20. 7. 2023 95. Kyaw Swar Lin a. k. a Kyaw Swar Linn Nationality: Myanmar/Burma; Place of birth: Myanmar/Burma; Gender: male; Function: Quartermaster General of the Myanmar armed forces Lieutenant General Kyaw Swar Lin was been appointed as Quartermaster General in May 2020. It is the sixth highest position in the military of Myanmar/Burma. The Office of the Quartermaster General is a department under the jurisdiction of the Ministry of Defense and is involved in arms and military equipment procurement for the Myanmar Armed Forces. In addition, Kyaw Swar Lin runs the Myanmar Economic Corporation (MEC), which is one of the two major conglomerates and holding companies operated by the military, generating revenue for the Myanmar armed forces (Tatmadaw). As Quartermaster General, he forms part of the military regime which has seized power in a military coup and overthrown the legitimately elected leaders of Myanmar/Burma. Kyaw Swar Lin is therefore a natural person whose policies and activities undermine democracy and the rule of law in Myanmar/Burma, and who provides support for actions that threaten the peace, security and stability of Myanmar/Burma. 20. 7. 2023 96. Myint Kyaing a. k. a. U Myint Kyaing Nationality: Myanmar/Burma; Date of birth: 17. 4. 1957 Place of birth: Myanmar/Burma; Gender: male; Function: Union Minister of Immigration and Population Myint Kyaing has been the Union Minister for Immigration and Population since 19Β August 2021. Before that, he was Union Minister of Labour following the coup of 1Β February 2021. He is a member of the State Administration Council (SAC), led by Commander-in-Chief Min Aung Hlaing, which took over the legislative, executive and judicial powers of the State in a military coup on 1Β February 2021. As a government Minister, he forms part of the military regime which has seized power in a military coup and overthrown the legitimately elected leaders of Myanmar/Burma. In his capacity as Union Minister, he carries out duties in support of military regime’s repressive immigration and population policy such as restrictions for citizens to travel within the country as well as the policy of the regime towards the minority of the Rohingya in violation of human rights. As Minister for Immigration and Population he also participates in preparations for the elections announced by the military in order to legitimise the illegal coup of February 2021. Myint Kyaing is therefore responsible for undermining democracy and the rule of law in Myanmar/Burma and for providing support for actions that threaten the peace, security and stability of Myanmar/Burma. 20. 7. 2023 97.",
10
+ "topk": 10,
11
+ "threshold": 0.16
12
+ }
13
+
14
+ # test the handler
15
+ payload_pred = my_handler(payload)
16
+
17
+ pprint(payload_pred)
18
+
train.ipynb ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a8b6caed",
6
+ "metadata": {},
7
+ "source": [
8
+ "# πŸ‡ͺπŸ‡Ί 🏷️ Eurovoc Model Training Notebook"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "c4c73793",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pickle \n",
19
+ "import pandas as pd\n",
20
+ "from transformers import AutoTokenizer, AutoModel\n",
21
+ "\n",
22
+ "from datasets import list_datasets, load_dataset\n",
23
+ "\n",
24
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
25
+ "import torch\n",
26
+ "\n",
27
+ "import pytorch_lightning as pl\n",
28
+ "from pytorch_lightning.callbacks import ModelCheckpoint"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "id": "dc770f0b",
34
+ "metadata": {
35
+ "tags": []
36
+ },
37
+ "source": [
38
+ "---\n",
39
+ "\n",
40
+ "## 1. Data loading\n",
41
+ "### Choose our dataset, extracted from ep registry or eurlex57k"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 2,
47
+ "id": "9fdc5328",
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "name": "stderr",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Found cached dataset json (/home/scampion/.cache/huggingface/datasets/EuropeanParliament___json/EuropeanParliament--cellar_eurovoc-3a27a019ebbf0296/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n"
55
+ ]
56
+ },
57
+ {
58
+ "data": {
59
+ "application/vnd.jupyter.widget-view+json": {
60
+ "model_id": "d5bf91bf9dc2416faefe96d680217da6",
61
+ "version_major": 2,
62
+ "version_minor": 0
63
+ },
64
+ "text/plain": [
65
+ " 0%| | 0/1 [00:00<?, ?it/s]"
66
+ ]
67
+ },
68
+ "metadata": {},
69
+ "output_type": "display_data"
70
+ }
71
+ ],
72
+ "source": [
73
+ "#dataset = load_dataset('json', data_files='ep_registry.jsonl')\n",
74
+ "\n",
75
+ "#dataset = load_dataset('eurlex')\n",
76
+ "dataset = load_dataset('EuropeanParliament/cellar_eurovoc')\n"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "id": "94967fc2",
82
+ "metadata": {},
83
+ "source": [
84
+ "### Merge train, test and validation"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 3,
90
+ "id": "ce5f764f",
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "<div>\n",
97
+ "<style scoped>\n",
98
+ " .dataframe tbody tr th:only-of-type {\n",
99
+ " vertical-align: middle;\n",
100
+ " }\n",
101
+ "\n",
102
+ " .dataframe tbody tr th {\n",
103
+ " vertical-align: top;\n",
104
+ " }\n",
105
+ "\n",
106
+ " .dataframe thead th {\n",
107
+ " text-align: right;\n",
108
+ " }\n",
109
+ "</style>\n",
110
+ "<table border=\"1\" class=\"dataframe\">\n",
111
+ " <thead>\n",
112
+ " <tr style=\"text-align: right;\">\n",
113
+ " <th></th>\n",
114
+ " <th>title</th>\n",
115
+ " <th>date</th>\n",
116
+ " <th>eurovoc_concepts</th>\n",
117
+ " <th>url</th>\n",
118
+ " <th>lang</th>\n",
119
+ " <th>formats</th>\n",
120
+ " <th>text</th>\n",
121
+ " </tr>\n",
122
+ " </thead>\n",
123
+ " <tbody>\n",
124
+ " <tr>\n",
125
+ " <th>0</th>\n",
126
+ " <td>Corrigendum to Commission Implementing Regulat...</td>\n",
127
+ " <td>2023-07-20</td>\n",
128
+ " <td>[China, Malaysia, anti-dumping duty, business ...</td>\n",
129
+ " <td>http://publications.europa.eu/resource/cellar/...</td>\n",
130
+ " <td>eng</td>\n",
131
+ " <td>[fmx4, pdfa2a, xhtml]</td>\n",
132
+ " <td>L_2023183EN. 01005801. xml 20. 7. 2023Β Β Β  EN O...</td>\n",
133
+ " </tr>\n",
134
+ " <tr>\n",
135
+ " <th>1</th>\n",
136
+ " <td>Council Decision (CFSP) 2023/1501 of 20Β July 2...</td>\n",
137
+ " <td>2023-07-20</td>\n",
138
+ " <td>[EU restrictive measure, Russia, Ukraine, econ...</td>\n",
139
+ " <td>http://publications.europa.eu/resource/cellar/...</td>\n",
140
+ " <td>eng</td>\n",
141
+ " <td>[fmx4, pdfa2a, xhtml]</td>\n",
142
+ " <td>LI2023183EN. 01004801. xml 20. 7. 2023Β Β Β  EN O...</td>\n",
143
+ " </tr>\n",
144
+ " <tr>\n",
145
+ " <th>2</th>\n",
146
+ " <td>Council Decision (CFSP) 2023/1502 of 20Β July 2...</td>\n",
147
+ " <td>2023-07-20</td>\n",
148
+ " <td>[Burma/Myanmar, EU restrictive measure, econom...</td>\n",
149
+ " <td>http://publications.europa.eu/resource/cellar/...</td>\n",
150
+ " <td>eng</td>\n",
151
+ " <td>[fmx4, pdfa2a, xhtml]</td>\n",
152
+ " <td>LI2023183EN. 01005201. xml 20. 7. 2023Β Β Β  EN O...</td>\n",
153
+ " </tr>\n",
154
+ " <tr>\n",
155
+ " <th>3</th>\n",
156
+ " <td>The Committee of the Regions welcomes Croatian...</td>\n",
157
+ " <td>2023-07-20</td>\n",
158
+ " <td>[Croatia, EU regional policy, European Committ...</td>\n",
159
+ " <td>http://publications.europa.eu/resource/cellar/...</td>\n",
160
+ " <td>eng</td>\n",
161
+ " <td>[pdf]</td>\n",
162
+ " <td>EUROPEAN UNION Committee of the Regions The Co...</td>\n",
163
+ " </tr>\n",
164
+ " <tr>\n",
165
+ " <th>4</th>\n",
166
+ " <td>Corrigendum to Commission Implementing Regulat...</td>\n",
167
+ " <td>2023-07-20</td>\n",
168
+ " <td>[India, TΓΌrkiye, anti-dumping duty, building m...</td>\n",
169
+ " <td>http://publications.europa.eu/resource/cellar/...</td>\n",
170
+ " <td>eng</td>\n",
171
+ " <td>[fmx4, pdfa2a, xhtml]</td>\n",
172
+ " <td>L_2023183EN. 01005901. xml 20. 7. 2023Β Β Β  EN O...</td>\n",
173
+ " </tr>\n",
174
+ " </tbody>\n",
175
+ "</table>\n",
176
+ "</div>"
177
+ ],
178
+ "text/plain": [
179
+ " title date \\\n",
180
+ "0 Corrigendum to Commission Implementing Regulat... 2023-07-20 \n",
181
+ "1 Council Decision (CFSP) 2023/1501 of 20Β July 2... 2023-07-20 \n",
182
+ "2 Council Decision (CFSP) 2023/1502 of 20Β July 2... 2023-07-20 \n",
183
+ "3 The Committee of the Regions welcomes Croatian... 2023-07-20 \n",
184
+ "4 Corrigendum to Commission Implementing Regulat... 2023-07-20 \n",
185
+ "\n",
186
+ " eurovoc_concepts \\\n",
187
+ "0 [China, Malaysia, anti-dumping duty, business ... \n",
188
+ "1 [EU restrictive measure, Russia, Ukraine, econ... \n",
189
+ "2 [Burma/Myanmar, EU restrictive measure, econom... \n",
190
+ "3 [Croatia, EU regional policy, European Committ... \n",
191
+ "4 [India, TΓΌrkiye, anti-dumping duty, building m... \n",
192
+ "\n",
193
+ " url lang \\\n",
194
+ "0 http://publications.europa.eu/resource/cellar/... eng \n",
195
+ "1 http://publications.europa.eu/resource/cellar/... eng \n",
196
+ "2 http://publications.europa.eu/resource/cellar/... eng \n",
197
+ "3 http://publications.europa.eu/resource/cellar/... eng \n",
198
+ "4 http://publications.europa.eu/resource/cellar/... eng \n",
199
+ "\n",
200
+ " formats text \n",
201
+ "0 [fmx4, pdfa2a, xhtml] L_2023183EN. 01005801. xml 20. 7. 2023Β Β Β  EN O... \n",
202
+ "1 [fmx4, pdfa2a, xhtml] LI2023183EN. 01004801. xml 20. 7. 2023Β Β Β  EN O... \n",
203
+ "2 [fmx4, pdfa2a, xhtml] LI2023183EN. 01005201. xml 20. 7. 2023Β Β Β  EN O... \n",
204
+ "3 [pdf] EUROPEAN UNION Committee of the Regions The Co... \n",
205
+ "4 [fmx4, pdfa2a, xhtml] L_2023183EN. 01005901. xml 20. 7. 2023Β Β Β  EN O... "
206
+ ]
207
+ },
208
+ "execution_count": 3,
209
+ "metadata": {},
210
+ "output_type": "execute_result"
211
+ }
212
+ ],
213
+ "source": [
214
+ "train = dataset['train'].to_pandas()\n",
215
+ "test = dataset['test'].to_pandas() if 'test' in dataset.keys() else None\n",
216
+ "validation = dataset['validation'].to_pandas() if 'validation' in dataset.keys() else None\n",
217
+ "\n",
218
+ "all = pd.concat([train, test, validation])#[:1000]\n",
219
+ "all.head()"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 4,
225
+ "id": "4c141dfa",
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "#all['eurovoc_concepts_str'] = all['eurovoc_concepts'].apply(str)"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "id": "aeca89c2",
235
+ "metadata": {
236
+ "tags": []
237
+ },
238
+ "source": [
239
+ "### Create the MultiLabel Binarizer and save it in a file for prediction"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 4,
245
+ "id": "d6846099",
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "data": {
250
+ "text/plain": [
251
+ "('Number of classes', 6835)"
252
+ ]
253
+ },
254
+ "execution_count": 4,
255
+ "metadata": {},
256
+ "output_type": "execute_result"
257
+ }
258
+ ],
259
+ "source": [
260
+ "mlb = MultiLabelBinarizer().fit(all['eurovoc_concepts'])\n",
261
+ "\n",
262
+ "pickle.dump(mlb, open('mlb.pickle', 'wb'))\n",
263
+ "\"Number of classes\", len(mlb.classes_)"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "markdown",
268
+ "id": "1f27b865",
269
+ "metadata": {
270
+ "tags": []
271
+ },
272
+ "source": [
273
+ "---\n",
274
+ "## 2. Split data using iterative train test "
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ba290237",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "import numpy as np\n",
285
+ "#X = np.array(all['text'].to_list())\n",
286
+ "#X = np.expand_dims(X, axis=1)\n",
287
+ "X = all['text'].to_numpy()\n",
288
+ "X = np.expand_dims(X, axis=1)\n",
289
+ "y = mlb.transform(all['eurovoc_concepts'])\n",
290
+ "\n",
291
+ "\n",
292
+ "from skmultilearn.model_selection import iterative_train_test_split\n",
293
+ "x_tr, y_tr, x_test, y_test = iterative_train_test_split(X, y, test_size = 0.1)\n",
294
+ "x_tr, y_tr, x_val, y_val = iterative_train_test_split(x_tr, y_tr, test_size = 0.1)\n",
295
+ "len(x_tr), len(x_val), len(x_test)"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "98371ad3",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "# Example \n",
306
+ "i = 10\n",
307
+ "x_tr[i][0][0:120], mlb.inverse_transform(np.expand_dims(y_tr[i], axis=1).T)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "id": "7c959b6a",
313
+ "metadata": {},
314
+ "source": [
315
+ "---\n",
316
+ "## 3. Model definition and training"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": null,
322
+ "id": "a177f1ce",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": []
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "id": "f4061399",
331
+ "metadata": {},
332
+ "outputs": [
333
+ {
334
+ "name": "stderr",
335
+ "output_type": "stream",
336
+ "text": [
337
+ "Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
338
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
339
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
340
+ "GPU available: True (cuda), used: True\n",
341
+ "TPU available: False, using: 0 TPU cores\n",
342
+ "IPU available: False, using: 0 IPUs\n",
343
+ "HPU available: False, using: 0 HPUs\n",
344
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
345
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n",
346
+ "\n",
347
+ " | Name | Type | Params\n",
348
+ "------------------------------------------\n",
349
+ "0 | bert | BertModel | 109 M \n",
350
+ "1 | dropout | Dropout | 0 \n",
351
+ "2 | classifier1 | Linear | 5.1 M \n",
352
+ "3 | criterion | BCELoss | 0 \n",
353
+ "------------------------------------------\n",
354
+ "114 M Trainable params\n",
355
+ "0 Non-trainable params\n",
356
+ "114 M Total params\n",
357
+ "458.304 Total estimated model params size (MB)\n",
358
+ "IOPub message rate exceeded.\n",
359
+ "The Jupyter server will temporarily stop sending output\n",
360
+ "to the client in order to avoid crashing it.\n",
361
+ "To change this limit, set the config variable\n",
362
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
363
+ "\n",
364
+ "Current values:\n",
365
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
366
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
367
+ "\n",
368
+ "IOPub message rate exceeded.\n",
369
+ "The Jupyter server will temporarily stop sending output\n",
370
+ "to the client in order to avoid crashing it.\n",
371
+ "To change this limit, set the config variable\n",
372
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
373
+ "\n",
374
+ "Current values:\n",
375
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
376
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
377
+ "\n",
378
+ "IOPub message rate exceeded.\n",
379
+ "The Jupyter server will temporarily stop sending output\n",
380
+ "to the client in order to avoid crashing it.\n",
381
+ "To change this limit, set the config variable\n",
382
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
383
+ "\n",
384
+ "Current values:\n",
385
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
386
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
387
+ "\n",
388
+ "IOPub message rate exceeded.\n",
389
+ "The Jupyter server will temporarily stop sending output\n",
390
+ "to the client in order to avoid crashing it.\n",
391
+ "To change this limit, set the config variable\n",
392
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
393
+ "\n",
394
+ "Current values:\n",
395
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
396
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
397
+ "\n",
398
+ "IOPub message rate exceeded.\n",
399
+ "The Jupyter server will temporarily stop sending output\n",
400
+ "to the client in order to avoid crashing it.\n",
401
+ "To change this limit, set the config variable\n",
402
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
403
+ "\n",
404
+ "Current values:\n",
405
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
406
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
407
+ "\n",
408
+ "IOPub message rate exceeded.\n",
409
+ "The Jupyter server will temporarily stop sending output\n",
410
+ "to the client in order to avoid crashing it.\n",
411
+ "To change this limit, set the config variable\n",
412
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
413
+ "\n",
414
+ "Current values:\n",
415
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
416
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
417
+ "\n",
418
+ "IOPub message rate exceeded.\n",
419
+ "The Jupyter server will temporarily stop sending output\n",
420
+ "to the client in order to avoid crashing it.\n",
421
+ "To change this limit, set the config variable\n",
422
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
423
+ "\n",
424
+ "Current values:\n",
425
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
426
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
427
+ "\n",
428
+ "IOPub message rate exceeded.\n",
429
+ "The Jupyter server will temporarily stop sending output\n",
430
+ "to the client in order to avoid crashing it.\n",
431
+ "To change this limit, set the config variable\n",
432
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
433
+ "\n",
434
+ "Current values:\n",
435
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
436
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
437
+ "\n",
438
+ "IOPub message rate exceeded.\n",
439
+ "The Jupyter server will temporarily stop sending output\n",
440
+ "to the client in order to avoid crashing it.\n",
441
+ "To change this limit, set the config variable\n",
442
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
443
+ "\n",
444
+ "Current values:\n",
445
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
446
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
447
+ "\n",
448
+ "IOPub message rate exceeded.\n",
449
+ "The Jupyter server will temporarily stop sending output\n",
450
+ "to the client in order to avoid crashing it.\n",
451
+ "To change this limit, set the config variable\n",
452
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
453
+ "\n",
454
+ "Current values:\n",
455
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
456
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
457
+ "\n",
458
+ "IOPub message rate exceeded.\n",
459
+ "The Jupyter server will temporarily stop sending output\n",
460
+ "to the client in order to avoid crashing it.\n",
461
+ "To change this limit, set the config variable\n",
462
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
463
+ "\n",
464
+ "Current values:\n",
465
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
466
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
467
+ "\n",
468
+ "IOPub message rate exceeded.\n",
469
+ "The Jupyter server will temporarily stop sending output\n",
470
+ "to the client in order to avoid crashing it.\n",
471
+ "To change this limit, set the config variable\n",
472
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
473
+ "\n",
474
+ "Current values:\n",
475
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
476
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
477
+ "\n",
478
+ "IOPub message rate exceeded.\n",
479
+ "The Jupyter server will temporarily stop sending output\n",
480
+ "to the client in order to avoid crashing it.\n",
481
+ "To change this limit, set the config variable\n",
482
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
483
+ "\n",
484
+ "Current values:\n",
485
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
486
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
487
+ "\n",
488
+ "IOPub message rate exceeded.\n",
489
+ "The Jupyter server will temporarily stop sending output\n",
490
+ "to the client in order to avoid crashing it.\n",
491
+ "To change this limit, set the config variable\n",
492
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
493
+ "\n",
494
+ "Current values:\n",
495
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
496
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
497
+ "\n",
498
+ "IOPub message rate exceeded.\n",
499
+ "The Jupyter server will temporarily stop sending output\n",
500
+ "to the client in order to avoid crashing it.\n",
501
+ "To change this limit, set the config variable\n",
502
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
503
+ "\n",
504
+ "Current values:\n",
505
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
506
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
507
+ "\n",
508
+ "IOPub message rate exceeded.\n",
509
+ "The Jupyter server will temporarily stop sending output\n",
510
+ "to the client in order to avoid crashing it.\n",
511
+ "To change this limit, set the config variable\n",
512
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
513
+ "\n",
514
+ "Current values:\n",
515
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
516
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
517
+ "\n",
518
+ "IOPub message rate exceeded.\n",
519
+ "The Jupyter server will temporarily stop sending output\n",
520
+ "to the client in order to avoid crashing it.\n",
521
+ "To change this limit, set the config variable\n",
522
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
523
+ "\n",
524
+ "Current values:\n",
525
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
526
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
527
+ "\n",
528
+ "IOPub message rate exceeded.\n",
529
+ "The Jupyter server will temporarily stop sending output\n",
530
+ "to the client in order to avoid crashing it.\n",
531
+ "To change this limit, set the config variable\n",
532
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
533
+ "\n",
534
+ "Current values:\n",
535
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
536
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
537
+ "\n",
538
+ "IOPub message rate exceeded.\n",
539
+ "The Jupyter server will temporarily stop sending output\n",
540
+ "to the client in order to avoid crashing it.\n",
541
+ "To change this limit, set the config variable\n",
542
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
543
+ "\n",
544
+ "Current values:\n",
545
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
546
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
547
+ "\n",
548
+ "IOPub message rate exceeded.\n",
549
+ "The Jupyter server will temporarily stop sending output\n",
550
+ "to the client in order to avoid crashing it.\n",
551
+ "To change this limit, set the config variable\n",
552
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
553
+ "\n",
554
+ "Current values:\n",
555
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
556
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
557
+ "\n",
558
+ "IOPub message rate exceeded.\n",
559
+ "The Jupyter server will temporarily stop sending output\n",
560
+ "to the client in order to avoid crashing it.\n",
561
+ "To change this limit, set the config variable\n",
562
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
563
+ "\n",
564
+ "Current values:\n",
565
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
566
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
567
+ "\n",
568
+ "IOPub message rate exceeded.\n",
569
+ "The Jupyter server will temporarily stop sending output\n",
570
+ "to the client in order to avoid crashing it.\n",
571
+ "To change this limit, set the config variable\n",
572
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
573
+ "\n",
574
+ "Current values:\n",
575
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
576
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
577
+ "\n",
578
+ "IOPub message rate exceeded.\n",
579
+ "The Jupyter server will temporarily stop sending output\n",
580
+ "to the client in order to avoid crashing it.\n",
581
+ "To change this limit, set the config variable\n",
582
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
583
+ "\n",
584
+ "Current values:\n",
585
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
586
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
587
+ "\n",
588
+ "IOPub message rate exceeded.\n",
589
+ "The Jupyter server will temporarily stop sending output\n",
590
+ "to the client in order to avoid crashing it.\n",
591
+ "To change this limit, set the config variable\n",
592
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
593
+ "\n",
594
+ "Current values:\n",
595
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
596
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
597
+ "\n",
598
+ "IOPub message rate exceeded.\n",
599
+ "The Jupyter server will temporarily stop sending output\n",
600
+ "to the client in order to avoid crashing it.\n",
601
+ "To change this limit, set the config variable\n",
602
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
603
+ "\n",
604
+ "Current values:\n",
605
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
606
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
607
+ "\n",
608
+ "IOPub message rate exceeded.\n",
609
+ "The Jupyter server will temporarily stop sending output\n",
610
+ "to the client in order to avoid crashing it.\n",
611
+ "To change this limit, set the config variable\n",
612
+ "`--ServerApp.iopub_msg_rate_limit`.\n",
613
+ "\n",
614
+ "Current values:\n",
615
+ "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
616
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
617
+ "\n"
618
+ ]
619
+ }
620
+ ],
621
+ "source": [
622
+ "%%capture output\n",
623
+ "%load_ext autoreload\n",
624
+ "%autoreload 2\n",
625
+ "\n",
626
+ "from eurovoc import EurovocTagger, EurovocDataset, EurovocDataModule\n",
627
+ "\n",
628
+ "\n",
629
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
630
+ "N_EPOCHS = 30\n",
631
+ "BATCH_SIZE = 10\n",
632
+ "MAX_LEN = 512\n",
633
+ "LR = 5e-05\n",
634
+ "\n",
635
+ "\n",
636
+ "# Instantiate and set up the data_module\n",
637
+ "dataloader = EurovocDataModule(BERT_MODEL_NAME, x_tr, y_tr, x_val, y_val , BATCH_SIZE, MAX_LEN)\n",
638
+ "dataloader.setup()\n",
639
+ "\n",
640
+ "\n",
641
+ "model = EurovocTagger(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n",
642
+ "\n",
643
+ "checkpoint_callback = ModelCheckpoint(\n",
644
+ " monitor='val_loss',\n",
645
+ " filename='EurovocTagger-{epoch:02d}-{val_loss:.2f}',\n",
646
+ " mode='min',\n",
647
+ ")\n",
648
+ "\n",
649
+ "trainer = pl.Trainer(max_epochs=N_EPOCHS , accelerator=\"gpu\", devices=1, callbacks=[checkpoint_callback])#,strategy=\"ddp_notebook\")\n",
650
+ "trainer.fit(model, dataloader)"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": 13,
656
+ "id": "19084e69",
657
+ "metadata": {},
658
+ "outputs": [],
659
+ "source": [
660
+ "trainer.save_checkpoint(\"eurovoc_cellar.ckpt\")"
661
+ ]
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": null,
666
+ "id": "d8289db5",
667
+ "metadata": {},
668
+ "outputs": [],
669
+ "source": [
670
+ "output()"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "code",
675
+ "execution_count": 14,
676
+ "id": "7c250c40",
677
+ "metadata": {},
678
+ "outputs": [],
679
+ "source": [
680
+ "np.save('x_test', x_test)\n",
681
+ "np.save('y_test', y_test)"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": 15,
687
+ "id": "418a7fd0",
688
+ "metadata": {},
689
+ "outputs": [
690
+ {
691
+ "name": "stderr",
692
+ "output_type": "stream",
693
+ "text": [
694
+ "/home/scampion/training/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:148: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n",
695
+ " rank_zero_warn(\n",
696
+ "Restoring states from the checkpoint path at /home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt\n",
697
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n",
698
+ "Loaded model weights from the checkpoint at /home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt\n",
699
+ "/home/scampion/training/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
700
+ " rank_zero_warn(\n"
701
+ ]
702
+ },
703
+ {
704
+ "data": {
705
+ "application/vnd.jupyter.widget-view+json": {
706
+ "model_id": "52fdd2fcc27744c4955dc449cc126100",
707
+ "version_major": 2,
708
+ "version_minor": 0
709
+ },
710
+ "text/plain": [
711
+ "Testing: 0it [00:00, ?it/s]"
712
+ ]
713
+ },
714
+ "metadata": {},
715
+ "output_type": "display_data"
716
+ },
717
+ {
718
+ "data": {
719
+ "text/html": [
720
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
721
+ "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
722
+ "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
723
+ "β”‚<span style=\"color: #008080; text-decoration-color: #008080\"> test_loss </span>β”‚<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0031269278842955828 </span>β”‚\n",
724
+ "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n",
725
+ "</pre>\n"
726
+ ],
727
+ "text/plain": [
728
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
729
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
730
+ "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
731
+ "β”‚\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.0031269278842955828 \u001b[0m\u001b[35m \u001b[0mβ”‚\n",
732
+ "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n"
733
+ ]
734
+ },
735
+ "metadata": {},
736
+ "output_type": "display_data"
737
+ },
738
+ {
739
+ "data": {
740
+ "text/plain": [
741
+ "[{'test_loss': 0.0031269278842955828}]"
742
+ ]
743
+ },
744
+ "execution_count": 15,
745
+ "metadata": {},
746
+ "output_type": "execute_result"
747
+ }
748
+ ],
749
+ "source": [
750
+ "trainer.test(dataloaders=dataloader)"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "markdown",
755
+ "id": "66b871ec",
756
+ "metadata": {},
757
+ "source": [
758
+ "# Evaluation"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": 16,
764
+ "id": "ba317c3e",
765
+ "metadata": {},
766
+ "outputs": [
767
+ {
768
+ "data": {
769
+ "text/plain": [
770
+ "'/home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt'"
771
+ ]
772
+ },
773
+ "execution_count": 16,
774
+ "metadata": {},
775
+ "output_type": "execute_result"
776
+ }
777
+ ],
778
+ "source": [
779
+ "best_model_path = trainer.checkpoint_callback.best_model_path\n",
780
+ "best_model_path"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "code",
785
+ "execution_count": 17,
786
+ "id": "fe9751a1",
787
+ "metadata": {},
788
+ "outputs": [
789
+ {
790
+ "name": "stderr",
791
+ "output_type": "stream",
792
+ "text": [
793
+ "Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
794
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
795
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
796
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 23243/23243 [16:20<00:00, 23.72it/s] \n"
797
+ ]
798
+ }
799
+ ],
800
+ "source": [
801
+ "from tqdm import tqdm\n",
802
+ "from transformers import AutoTokenizer\n",
803
+ "\n",
804
+ "trained_model = EurovocTagger.load_from_checkpoint(best_model_path,\n",
805
+ " bert_model_name=BERT_MODEL_NAME,\n",
806
+ " n_classes=len(mlb.classes_))\n",
807
+ "trained_model.eval()\n",
808
+ "trained_model.freeze()\n",
809
+ "\n",
810
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
811
+ "trained_model = trained_model.to(device)\n",
812
+ "\n",
813
+ "tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n",
814
+ "\n",
815
+ "val_dataset = EurovocDataset(x_test, y_test, tokenizer, max_token_len=MAX_LEN)\n",
816
+ "predictions = []\n",
817
+ "labels = []\n",
818
+ "\n",
819
+ "for item in tqdm(val_dataset):\n",
820
+ " _, prediction = trained_model(\n",
821
+ " item[\"input_ids\"].unsqueeze(dim=0).to(device), \n",
822
+ " item[\"attention_mask\"].unsqueeze(dim=0).to(device)\n",
823
+ " )\n",
824
+ " predictions.append(prediction.flatten())\n",
825
+ " labels.append(item[\"labels\"].int())\n",
826
+ "\n",
827
+ "predictions = torch.stack(predictions).detach().cpu()\n",
828
+ "labels = torch.stack(labels).detach().cpu()"
829
+ ]
830
+ },
831
+ {
832
+ "cell_type": "markdown",
833
+ "id": "67477f7f",
834
+ "metadata": {},
835
+ "source": [
836
+ "### F1 Score"
837
+ ]
838
+ },
839
+ {
840
+ "cell_type": "code",
841
+ "execution_count": 18,
842
+ "id": "f0265f6e",
843
+ "metadata": {},
844
+ "outputs": [
845
+ {
846
+ "name": "stdout",
847
+ "output_type": "stream",
848
+ "text": [
849
+ "0.01 tensor(0.2188)\n",
850
+ "0.06 tensor(0.3929)\n",
851
+ "0.11 tensor(0.4353)\n",
852
+ "0.16 tensor(0.4462)\n",
853
+ "0.21 tensor(0.4437)\n",
854
+ "0.26 tensor(0.4364)\n",
855
+ "0.31 tensor(0.4249)\n",
856
+ "0.36 tensor(0.4106)\n",
857
+ "0.41 tensor(0.3947)\n",
858
+ "0.46 tensor(0.3780)\n",
859
+ "0.51 tensor(0.3597)\n",
860
+ "0.56 tensor(0.3404)\n",
861
+ "0.61 tensor(0.3209)\n",
862
+ "0.66 tensor(0.3007)\n"
863
+ ]
864
+ }
865
+ ],
866
+ "source": [
867
+ "from torchmetrics import F1Score\n",
868
+ "for i in range(1, 70, 5):\n",
869
+ " f1 = F1Score(task=\"multilabel\", num_labels=len(mlb.classes_), average='weighted', threshold= i / 100.0)\n",
870
+ " print(i / 100.0, f1(predictions, labels))"
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "markdown",
875
+ "id": "0945ad49",
876
+ "metadata": {},
877
+ "source": [
878
+ "### NDCG Score"
879
+ ]
880
+ },
881
+ {
882
+ "cell_type": "code",
883
+ "execution_count": null,
884
+ "id": "e4e3291f",
885
+ "metadata": {},
886
+ "outputs": [],
887
+ "source": [
888
+ "from sklearn.metrics import ndcg_score\n",
889
+ "def calculate_average_ndcg(predictions, labels, top_k=5):\n",
890
+ " # Initialize a list to store NDCG scores for each sample\n",
891
+ " ndcg_scores = []\n",
892
+ "\n",
893
+ " # Calculate NDCG for each sample\n",
894
+ " for i in range(len(predictions)):\n",
895
+ " # Convert tensors to numpy arrays\n",
896
+ " y_true = labels[i].cpu().numpy().reshape(1, -1)\n",
897
+ " y_score = predictions[i].cpu().numpy().reshape(1, -1)\n",
898
+ " \n",
899
+ " # Calculate NDCG for the sample\n",
900
+ " ndcg = ndcg_score(y_true, y_score, k=top_k)\n",
901
+ " ndcg_scores.append(ndcg)\n",
902
+ "\n",
903
+ " # Calculate the average NDCG score\n",
904
+ " average_ndcg = np.mean(ndcg_scores)\n",
905
+ " \n",
906
+ " return average_ndcg\n",
907
+ "\n",
908
+ "for k in [3, 5, 10]:\n",
909
+ " average = calculate_average_ndcg(predictions, labels, top_k=k)\n",
910
+ " print(\"NDCG@\"+str(k)+\": \"+ str(round(average, 4)))"
911
+ ]
912
+ }
913
+ ],
914
+ "metadata": {
915
+ "kernelspec": {
916
+ "display_name": "eurovoc-env",
917
+ "language": "python",
918
+ "name": "eurovoc-env"
919
+ },
920
+ "language_info": {
921
+ "codemirror_mode": {
922
+ "name": "ipython",
923
+ "version": 3
924
+ },
925
+ "file_extension": ".py",
926
+ "mimetype": "text/x-python",
927
+ "name": "python",
928
+ "nbconvert_exporter": "python",
929
+ "pygments_lexer": "ipython3",
930
+ "version": "3.10.12"
931
+ }
932
+ },
933
+ "nbformat": 4,
934
+ "nbformat_minor": 5
935
+ }