This is a fine-tuned version of [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://arxiv.org/abs/2104.08821).

- Trained with supervision on 100K triplet samples related to the stroke domain, drawn from stroke books, medical Quora, stroke-related Quora, general Quora, and human annotations.
- Positive sentences are generated by paraphrasing and back-translation.
- Negative sentences are randomly sampled from the general domain.
### Extract sentence representation
```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("demdecuong/stroke_sup_simcse")
model = AutoModel.from_pretrained("demdecuong/stroke_sup_simcse")

text = "What are diseases related to red stroke's causes?"

inputs = tokenizer(text, return_tensors='pt')
outputs = model(**inputs)[1]  # pooler output: a single 768-dim sentence embedding
```
### Build up embedding for database
```python
import torch

database = [
    'What is the daily checklist for stroke returning home',
    'What are some tips for stroke adapt new life',
    'What should I consider when using nursing-home care',
]

embedding = torch.zeros((len(database), 768))

with torch.no_grad():  # inference only, no gradients needed
    for i in range(len(database)):
        inputs = tokenizer(database[i], return_tensors="pt")
        outputs = model(**inputs)[1]  # pooler output for this sentence
        embedding[i] = outputs

print(embedding.shape)  # torch.Size([3, 768])
```
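With the database matrix built, a query embedding can be matched against it by cosine similarity. A minimal sketch (the helper `top1_match` is ours, not part of the model's API):

```python
import torch
import torch.nn.functional as F

def top1_match(query_emb, db_emb):
    """Index of the database row with the highest cosine similarity to the query."""
    q = F.normalize(query_emb, dim=-1).squeeze(0)  # (dim,)
    d = F.normalize(db_emb, dim=-1)                # (n, dim)
    return int(torch.argmax(d @ q))
```

For example, `database[top1_match(outputs, embedding)]` returns the stored question closest to the query.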
### Results

For our company's PoC project, the test set contains human-generated positive/negative pairs of matching questions related to stroke.
- SimCSE supervised + 100K: trained on 100K triplet samples covering the medical, stroke, and general domains
- SimCSE supervised + 42K: trained on 42K triplet samples covering the medical and stroke domains
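Top-1 accuracy in the table below is the fraction of test queries whose highest-scoring candidate is the gold match; a minimal sketch (the helper name is ours):

```python
def top1_accuracy(predicted, gold):
    """Percentage of queries whose top-ranked candidate equals the gold one."""
    correct = sum(p == g for p, g in zip(predicted, gold))
    return 100.0 * correct / len(gold)
```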
| Model | Top-1 Accuracy (%) |
| ------------- | ------------- |
| SimCSE supervised (author) | 75.83 |
| SimCSE unsupervised (ours) | 76.66 |
| SimCSE supervised + 100K (ours) | 73.33 |
| SimCSE supervised + 42K (ours) | 75.83 |