This is a fine-tuned version of [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://arxiv.org/abs/2104.08821).
- Trained with supervision on 100K triplet samples related to the stroke domain, collected from stroke books, medical Quora, stroke-related Quora, general Quora, and human annotation (a hypothetical triplet is sketched after this list).
- Positive sentences are generated by paraphrasing and back-translation.
- Negative sentences are randomly selected from the general domain.
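
For illustration only, a supervised SimCSE triplet pairs an anchor sentence with one positive and one negative; the example below is hypothetical and not taken from the actual training data:
```
# A hypothetical triplet sample (anchor, positive, negative) -- illustrative only
triplet = {
    "anchor": "What are the early warning signs of a stroke?",
    "positive": "Which symptoms indicate that a stroke may be starting?",  # paraphrase / back-translation
    "negative": "How do I renew my driving licence online?",              # random general-domain sentence
}
```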
### Extract sentence representation
```
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("demdecuong/stroke_sup_simcse")
model = AutoModel.from_pretrained("demdecuong/stroke_sup_simcse")

text = "What are disease related to red stroke's causes?"
inputs = tokenizer(text, return_tensors='pt')
# Index 1 is the pooler output: one 768-dim vector for the whole sentence
outputs = model(**inputs)[1]
```
### Build up embedding for database
```
import torch

database = [
    'What is the daily checklist for stroke returning home',
    'What are some tips for stroke adapt new life',
    'What should I consider when using nursing-home care'
]

# Pre-compute one 768-dim pooler embedding per database sentence
embedding = torch.zeros((len(database), 768))
with torch.no_grad():
    for i in range(len(database)):
        inputs = tokenizer(database[i], return_tensors="pt")
        outputs = model(**inputs)[1]
        embedding[i] = outputs
print(embedding.shape)  # torch.Size([3, 768])
```
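To match a new question against the database, one straightforward option (a sketch, not part of the original card) is to embed the query the same way and rank the database sentences by cosine similarity:
```
import torch.nn.functional as F

# Embed the query with the same pooler output as above
query = "What are disease related to red stroke's causes?"
with torch.no_grad():
    query_emb = model(**tokenizer(query, return_tensors="pt"))[1]  # shape (1, 768)

# Cosine similarity between the query and every database sentence
scores = F.cosine_similarity(query_emb, embedding)  # shape (len(database),)
best = scores.argmax().item()
print(database[best], scores[best].item())
```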
### Results
On our company's PoC project, the test set consists of human-generated positive/negative pairs of matching questions related to stroke.
- SimCSE supervised + 100k : trained on 100K triplet samples covering the medical, stroke, and general domains
- SimCSE supervised + 42k : trained on 42K triplet samples covering the medical and stroke domains
| Model | Top-1 Accuracy (%) |
| ------------- | ------------- |
| SimCSE supervised (author) | 75.83 |
| SimCSE unsupervised (ours) | 76.66 |
| SimCSE supervised + 100k (ours) | 73.33 |
| SimCSE supervised + 42k (ours) | 75.83 |