Update README.md

…and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning).

The model backbone is albert-xxlarge-v2.
## Model description

The model is tuned with unlabeled data using a learning objective called first sentence prediction (FSP).
The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks.
The training and validation sets are constructed from the unlabeled corpus using FSP.

During tuning, BERT-like pre-trained masked language models such as RoBERTa and ALBERT are employed as the backbone, and an output layer for classification is added.
The learning objective for FSP is to predict the index of the correct label.
A cross-entropy loss is used for tuning the model.
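As a rough illustration of the FSP format (a minimal sketch, not the exact procedure from the paper), a training instance can be built by taking the first sentence of a paragraph as the correct option, sampling first sentences from other paragraphs as distractors, and asking the model to predict the index of the correct option. The helper below is hypothetical and only meant to show the shape of the data:

```python
import random, string

def build_fsp_instance(paragraph, other_paragraphs, num_options=5, sep_token="</s>"):
    """Hypothetical sketch of constructing one FSP training example.

    The first sentence of `paragraph` is the correct option; first sentences of
    other paragraphs act as distractors. The target is the index of the correct
    option, so a standard cross-entropy loss can be applied.
    """
    first, _, rest = paragraph.partition(". ")
    distractors = random.sample([p.partition(". ")[0] for p in other_paragraphs],
                                num_options - 1)
    options = distractors + [first]
    random.shuffle(options)
    # Present the options as "(A) ... (B) ..." followed by the rest of the paragraph.
    prefix = " ".join(f"({string.ascii_uppercase[i]}) {opt}." for i, opt in enumerate(options))
    return f"{prefix} {sep_token} {rest}", options.index(first)
```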
## Model variations

Three versions of the model are released. The details are:

| Model | Backbone | #Params | Accuracy | Speed | #Training data |
|-------|----------|---------|----------|-------|----------------|
| [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base) | [roberta-base](https://huggingface.co/roberta-base) | 125M | Low | High | 20.48M |
| [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large) | [roberta-large](https://huggingface.co/roberta-large) | 355M | Medium | Medium | 5.12M |
| [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) | [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2) | 235M | High | Low | 5.12M |

Please note that zero-shot-classify-SSTuning-base is trained with more data (20.48M) than reported in the paper, as this increases accuracy.
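All three checkpoints share the same interface, so switching variants only requires changing the checkpoint name. For example, to trade inference speed for accuracy, the ALBERT variant can be loaded instead of the base one:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Same API for every variant; only the checkpoint name changes.
checkpoint = "DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
```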
## Intended uses & limitations

The model can be used for zero-shot text classification such as sentiment analysis and topic classification. No further fine-tuning is needed.

The number of labels should be between 2 and 20.
### How to use

You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

list_ABC = [x for x in string.ascii_uppercase]

def add_prefix(text, list_label, shuffle=False):
    # Make every label end with a period, then pad the option list to 20 entries.
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Prepend the options as "(A) ... (B) ..." and separate them from the input text.
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    return f'{s_option} {tokenizer.sep_token} {text}', list_label_new

text_new, list_label_new = add_prefix(text, list_label, shuffle=False)

encoding = tokenizer([text_new], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = model(**encoding).logits
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    # The predicted index corresponds to a position in list_label_new, e.g. 1 -> "positive.".
    predictions = torch.argmax(logits, dim=-1)

print(probs)
print(predictions)
```
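The same pattern extends to any label set within the 2–20 range. As a further hypothetical example (the text and labels below are illustrative, not from the model card), topic classification only requires changing `text` and `list_label`:

```python
# Hypothetical topic-classification example reusing the objects defined above.
text = "The team scored twice in the final minutes to win the championship."
list_label = ["politics", "sports", "technology", "business"]

text_new, list_label_new = add_prefix(text, list_label, shuffle=False)
encoding = tokenizer([text_new], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = model(**encoding).logits
    prediction = torch.argmax(logits, dim=-1).item()

# Map the predicted option index back to a label string.
print(list_label_new[prediction])
```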

### BibTeX entry and citation info

```bibtex
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {},
}
```