This is a fine-tuned version of [SimCSE: Simple Contrastive Learning of Sentence Embeddings](https://arxiv.org/abs/2104.08821).

- Trained supervised on 100K triplet samples related to the stroke domain, drawn from stroke books, medical Quora, stroke-related Quora, general Quora, and human annotations.
- Positive sentences are generated by paraphrasing and back-translation.
- Negative sentences are randomly selected from the general domain.

### Extract sentence representation

```
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("demdecuong/stroke_sup_simcse")
model = AutoModel.from_pretrained("demdecuong/stroke_sup_simcse")

text = "What are disease related to red stroke's causes?"
inputs = tokenizer(text, return_tensors='pt')
outputs = model(**inputs)[1]  # pooler output, shape (1, 768)
```

### Build up embedding for database

```
import torch

database = [
    'What is the daily checklist for stroke returning home',
    'What are some tips for stroke adapt new life',
    'What should I consider when using nursing-home care'
]

embedding = torch.zeros((len(database), 768))

for i in range(len(database)):
    inputs = tokenizer(database[i], return_tensors="pt")
    outputs = model(**inputs)[1]  # pooler output for sentence i
    embedding[i] = outputs

print(embedding.shape)
```

### Result

On our company's PoC project, the test set contains human-generated positive/negative pairs of matching questions related to stroke.

- SimCSE supervised + 100k: trained on 100K triplet samples covering the medical, stroke, and general domains
- SimCSE supervised + 42k: trained on 42K triplet samples covering the medical and stroke domains

| Model | Top-1 Accuracy |
| ------------- | ------------- |
| SimCSE supervised (author) | 75.83 |
| SimCSE unsupervised (ours) | 76.66 |
| SimCSE supervised + 100k (ours) | 73.33 |
| SimCSE supervised + 42k (ours) | 75.83 |
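
### Query the database

A minimal retrieval sketch that builds on the snippets above (it reuses `tokenizer`, `model`, `database`, and `embedding`). The query string is made up for illustration, and cosine similarity is assumed as the scoring function, matching the similarity measure used by SimCSE.

```
import torch
import torch.nn.functional as F

query = "How do I adjust to daily life after a stroke?"  # hypothetical query for illustration

with torch.no_grad():
    inputs = tokenizer(query, return_tensors="pt")
    query_emb = model(**inputs)[1]  # pooler output, shape (1, 768)

# Cosine similarity between the query and every database embedding
scores = F.cosine_similarity(query_emb, embedding)  # shape (len(database),)
best = torch.argmax(scores).item()
print(database[best], scores[best].item())
```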