Update README.md
Browse files
README.md
CHANGED
@@ -9,7 +9,7 @@ We introduce the model for multilabel ESG risks classification. There is 47 clas
|
|
9 |
|
10 |
## Usage
|
11 |
```python
|
12 |
-
|
13 |
from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
|
14 |
import torch
|
15 |
#Mean Pooling - Take attention mask into account for correct averaging
|
@@ -45,10 +45,11 @@ class ESGify(MPNetPreTrainedModel):
|
|
45 |
outputs = self.mpnet(input_ids=input_ids,
|
46 |
attention_mask=attention_mask)
|
47 |
|
48 |
-
# mean pooling dataset
|
49 |
logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
|
50 |
-
|
51 |
-
|
|
|
52 |
return logits
|
53 |
|
54 |
model = ESGify.from_pretrained('ai-lab/ESGify')
|
|
|
9 |
|
10 |
## Usage
|
11 |
```python
|
12 |
+
from collections import OrderedDict
|
13 |
from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
|
14 |
import torch
|
15 |
#Mean Pooling - Take attention mask into account for correct averaging
|
|
|
45 |
outputs = self.mpnet(input_ids=input_ids,
|
46 |
attention_mask=attention_mask)
|
47 |
|
48 |
+
# mean pooling dataset and eed input to classifier to compute logits
|
49 |
logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
|
50 |
+
|
51 |
+
# apply sigmoid
|
52 |
+
logits = 1.0 / (1.0 + torch.exp(-logits))
|
53 |
return logits
|
54 |
|
55 |
model = ESGify.from_pretrained('ai-lab/ESGify')
|