playmak3r commited on
Commit
00933b9
·
0 Parent(s):

initial commit

Browse files
Files changed (4) hide show
  1. requirements.txt +5 -0
  2. server.py +36 -0
  3. similarity.py +66 -0
  4. tests/test.py +9 -0
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ sacrebleu
2
+ torch
3
+ sentence_transformers
4
+ fastapi
5
+ uvicorn
server.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from similarity import get_similarity_batched, get_bleu, get_chrf
6
+
7
+
8
# Application object served by uvicorn; interactive Swagger UI is exposed at /docs.
app = FastAPI(
    title="Sentence similarity API",
    description="Check Sentences similarities.",
    version="1.0"
)
13
+
14
+
15
class Texts(BaseModel):
    """Request body for /api/similarity.

    Two parallel lists of sentences; the endpoint scores texts1[i]
    against texts2[i] pairwise, so both lists must be the same length.
    """
    texts1: List[str]
    texts2: List[str]
18
+
19
+ @app.get("/")
20
+ def home():
21
+ #return {"mensagem": "Bem-vindo à API!"}
22
+ return RedirectResponse(url="/docs")
23
+
24
+
25
@app.post('/api/similarity')
def get_sim(texts: Texts):
    """Score each texts1[i]/texts2[i] pair.

    Returns a list of dicts with sentence-level BLEU, chrF, and embedding
    cosine similarity for every aligned pair.
    """
    # Fail fast with a client error; otherwise a length mismatch surfaces
    # as an IndexError (HTTP 500) below, since clean_text_batch silently
    # skips cleaning when the lists differ in length.
    if len(texts.texts1) != len(texts.texts2):
        raise HTTPException(status_code=400,
                            detail="texts1 and texts2 must have the same length")
    # NOTE: this call also cleans both lists in place before scoring.
    sim = get_similarity_batched(texts.texts1, texts.texts2)
    return [
        {
            "bleu": get_bleu(t1, t2),
            "chrf": get_chrf(t1, t2),
            # float() unwraps a 0-dim torch tensor (or passes a float through)
            # so the response is JSON-serializable.
            "similarity": float(s),
        }
        for t1, t2, s in zip(texts.texts1, texts.texts2, sim)
    ]
similarity.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sacrebleu
2
+ import re
3
+ from typing import List
4
+
5
+
6
# Lazily-initialized SentenceTransformer shared across calls (loaded on first use).
st_model = None

def get_similarity_batched(texts1: List[str], texts2: List[str]):
    """Return cosine similarity of each texts1[i]/texts2[i] pair.

    NOTE: mutates both input lists in place (stricter cleaning) before
    encoding — callers see the cleaned texts afterwards.

    Returns a plain list of floats (one per pair) rather than a torch
    tensor, so the result is directly JSON-serializable by FastAPI.
    """
    # Imported lazily so importing this module stays cheap when only the
    # BLEU/chrF helpers are needed.
    import torch
    from sentence_transformers import SentenceTransformer, util
    global st_model
    if st_model is None:
        # Other candidate models considered:
        #   paraphrase-multilingual-mpnet-base-v2, all-MiniLM-L12-v2,
        #   all-distilroberta-v1, all-MiniLM-L6-v2
        st_model = SentenceTransformer('all-mpnet-base-v2',
                                       device='cuda' if torch.cuda.is_available() else 'cpu',
                                       cache_folder="./s_cache")

    clean_text_batch(texts1, texts2)
    embeddings1 = st_model.encode(texts1, convert_to_tensor=True, show_progress_bar=False)
    embeddings2 = st_model.encode(texts2, convert_to_tensor=True, show_progress_bar=False)
    cosine_scores = util.cos_sim(embeddings1, embeddings2)
    # Diagonal holds texts1[i] vs texts2[i]; off-diagonal entries are
    # cross-pair scores we do not need. Convert to Python floats.
    return cosine_scores.diag().tolist()
24
+
25
def clean_text_batch(texts1: List[str], texts2: List[str]):
    """Apply stricter cleaning to both lists in place.

    Silently does nothing when the lists differ in length.
    """
    if len(texts1) != len(texts2):
        return
    for i, (a, b) in enumerate(zip(texts1, texts2)):
        texts1[i] = clean_text(a, stricter=True)
        texts2[i] = clean_text(b, stricter=True)
31
+
32
+
33
def clean_text(text, stricter=False):
    """Strip decorative punctuation and whitespace from *text*.

    With stricter=True, also collapses stutter patterns such as
    "w-w-what" -> "what" and strips sentence-ending punctuation.
    """
    if stricter:
        # Collapse case-insensitive repeated-letter stutter ("w-w-what").
        text = re.sub(r"([^a-zA-Z]|^)([a-zA-Z])(?i:-\2)+([a-zA-Z])", r"\1\2\3", text)
    extra = "….??!!," if stricter else ""
    return text.strip("&っ。~―()「」「」『』“”\"',、○()«»~ \t\r\n" + extra)
41
+
42
def get_similarity(ref, hyp):
    """Similarity of a single reference/hypothesis pair in [0, 1]-ish cosine space.

    An empty (post-cleaning) reference or a case-insensitive exact match
    scores a perfect 1.0 without invoking the model.
    """
    ref = clean_text(ref, stricter=True)
    if not ref:
        return 1.0
    hyp = clean_text(hyp, stricter=True)
    if ref.lower() == hyp.lower():
        return 1.0
    score = get_similarity_batched([ref], [hyp])[0]
    return float(score)
50
+
51
def get_bleu(ref, hyp):
    """Sentence-level BLEU (0-100) of *hyp* against *ref*.

    Case-insensitive exact matches short-circuit to a perfect 100.
    """
    ref, hyp = clean_text(ref), clean_text(hyp)
    if ref.lower() == hyp.lower():
        return 100
    return sacrebleu.sentence_bleu(hyp, [ref]).score
58
+
59
def get_chrf(ref, hyp):
    """Sentence-level chrF (0-100) of *hyp* against *ref*.

    Case-insensitive exact matches short-circuit to a perfect 100.
    """
    ref, hyp = clean_text(ref), clean_text(hyp)
    if ref.lower() == hyp.lower():
        return 100
    return sacrebleu.sentence_chrf(hyp, [ref]).score
66
+
tests/test.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
import requests


# Smoke test: POST two Portuguese paraphrase pairs to a locally running
# server and print the scored response.
payload = {
    "texts1": ["Eu gosto de andar de bicicleta nas manhãs de domingo.", "A entrega está programada para amanhã à tarde."],
    "texts2": ["Aos domingos de manhã, eu adoro pedalar.", "A remessa vai chegar amanhã no período da tarde."],
}
response = requests.post("http://localhost:8000/api/similarity", json=payload)
print(response.json())