Spaces:
Runtime error
Runtime error
Commit
·
00933b9
0
Parent(s):
initial commit
Browse files- requirements.txt +5 -0
- server.py +36 -0
- similarity.py +66 -0
- tests/test.py +9 -0
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sacrebleu
|
2 |
+
torch
|
3 |
+
sentence_transformers
|
4 |
+
fastapi
|
5 |
+
uvicorn
|
server.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from fastapi import FastAPI, Request
|
3 |
+
from fastapi.responses import RedirectResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from similarity import get_similarity_batched, get_bleu, get_chrf
|
6 |
+
|
7 |
+
|
8 |
+
# The single FastAPI application object; every route in this module is
# registered on it through the @app.get / @app.post decorators.
app = FastAPI(
    title="Sentence similarity API",
    description="Check Sentences similarities.",
    version="1.0",
)
|
13 |
+
|
14 |
+
|
15 |
+
class Texts(BaseModel):
    """Request body of ``POST /api/similarity``: two parallel sentence lists.

    The sentence at index ``i`` of ``texts1`` is compared with the sentence
    at index ``i`` of ``texts2``.
    """

    # Passed as the ``ref`` argument of the BLEU / chrF helpers.
    texts1: List[str]
    # Passed as the ``hyp`` argument of the BLEU / chrF helpers.
    texts2: List[str]
|
18 |
+
|
19 |
+
@app.get("/")
def home():
    """Redirect requests for the root URL to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
|
23 |
+
|
24 |
+
|
25 |
+
@app.post('/api/similarity')
def get_sim(texts: Texts):
    """Score every pair ``(texts1[i], texts2[i])`` of the request body.

    Args:
        texts: Request body holding the two parallel sentence lists.

    Returns:
        A list with one dict per pair, in input order, with the keys
        ``"bleu"``, ``"chrf"`` and ``"similarity"`` (all plain numbers).

    Raises:
        HTTPException: 422 when the two lists do not have the same length.
    """
    # Imported locally because the module-level fastapi import does not
    # bring HTTPException into scope.
    from fastapi import HTTPException

    if len(texts.texts1) != len(texts.texts2):
        # Without this check a shorter ``texts2`` ends in an IndexError
        # (HTTP 500) instead of a clear client error.
        raise HTTPException(
            status_code=422,
            detail="texts1 and texts2 must have the same length",
        )

    if not texts.texts1:
        return []

    # Hand over copies: the embedding helper may clean its arguments in
    # place, and BLEU / chrF must be computed on the text the client sent.
    sim = get_similarity_batched(list(texts.texts1), list(texts.texts2))

    return [
        {
            "bleu": get_bleu(ref, hyp),
            "chrf": get_chrf(ref, hyp),
            # Each score is a 0-d tensor; convert it so the response
            # carries a plain JSON number.
            "similarity": float(score),
        }
        for ref, hyp, score in zip(texts.texts1, texts.texts2, sim)
    ]
|
similarity.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sacrebleu
|
2 |
+
import re
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
# SentenceTransformer instance, created on the first call and reused after.
st_model = None


def get_similarity_batched(texts1: List[str], texts2: List[str]):
    """Return the cosine similarity of each pair ``(texts1[i], texts2[i])``.

    Every sentence is cleaned with ``clean_text(..., stricter=True)`` before
    it is embedded. The cleaning works on copies, so the caller's lists are
    never modified.

    Args:
        texts1: First sentence of each pair.
        texts2: Second sentence of each pair.

    Returns:
        A 1-D torch tensor with one score per pair. When the lists differ
        in length only the first ``min(len(texts1), len(texts2))`` pairs
        are scored; an empty tensor is returned when there is no pair.
    """
    global st_model

    import torch

    pair_count = min(len(texts1), len(texts2))
    if pair_count == 0:
        # Nothing to compare, so do not pay for loading the model.
        return torch.empty(0)

    from sentence_transformers import SentenceTransformer

    if st_model is None:
        # Other checkpoints noted as alternatives:
        #   paraphrase-multilingual-mpnet-base-v2, all-MiniLM-L12-v2,
        #   all-distilroberta-v1, all-MiniLM-L6-v2
        st_model = SentenceTransformer(
            'all-mpnet-base-v2',
            device='cuda' if torch.cuda.is_available() else 'cpu',
            cache_folder="./s_cache",
        )

    cleaned1 = [clean_text(text, stricter=True) for text in texts1[:pair_count]]
    cleaned2 = [clean_text(text, stricter=True) for text in texts2[:pair_count]]

    embeddings1 = st_model.encode(cleaned1, convert_to_tensor=True, show_progress_bar=False)
    embeddings2 = st_model.encode(cleaned2, convert_to_tensor=True, show_progress_bar=False)

    # Row-wise cosine similarity: only the n matching pairs are computed,
    # instead of a full n x n matrix whose diagonal is then extracted.
    return torch.nn.functional.cosine_similarity(embeddings1, embeddings2, dim=1)
|
24 |
+
|
25 |
+
def clean_text_batch(texts1: List[str], texts2: List[str]):
    """Strictly clean every string of both lists, in place.

    Each list is processed on its own, so the strings are cleaned even when
    the two lists do not have the same length.

    Args:
        texts1: List whose items are replaced by their cleaned form.
        texts2: List whose items are replaced by their cleaned form.

    Returns:
        None. The lists passed in are modified.
    """
    for texts in (texts1, texts2):
        # Slice assignment keeps the caller's list object and swaps its items.
        texts[:] = [clean_text(text, stricter=True) for text in texts]
|
31 |
+
|
32 |
+
|
33 |
+
# Matches a stuttered word start such as "W-w-what": a letter at the start of
# the text or after a non-letter, one or more "-<same letter>" repeats (case
# ignored), then the next letter of the word. Compiled once because
# clean_text is called per item inside batch loops.
_STUTTER_RE = re.compile(r"([^a-zA-Z]|^)([a-zA-Z])(?i:-\2)+([a-zA-Z])")

# Characters removed from both ends of the text in every mode.
_STRIP_CHARS = "&っ。~―()「」「」『』“”\"',、○()«»~ \t\r\n"

# Strict mode additionally removes sentence punctuation from both ends.
_STRICT_STRIP_CHARS = _STRIP_CHARS + "….??!!,"


def clean_text(text: str, stricter: bool = False) -> str:
    """Strip surrounding quotes, brackets and whitespace from ``text``.

    Args:
        text: The string to clean.
        stricter: When true, stuttered word starts are collapsed first
            ("W-w-what" becomes "What") and sentence punctuation
            (``… . ? ! ,``) is stripped from both ends as well.

    Returns:
        The cleaned string. Characters inside the text are kept, apart from
        the collapsed stutters in strict mode.
    """
    if not stricter:
        return text.strip(_STRIP_CHARS)

    text = _STUTTER_RE.sub(r"\1\2\3", text)
    return text.strip(_STRICT_STRIP_CHARS)
|
41 |
+
|
42 |
+
def get_similarity(ref, hyp):
    """Return the embedding cosine similarity of one ``ref`` / ``hyp`` pair.

    Both strings are strictly cleaned first. The result is ``1.0`` without
    running the model when the cleaned reference is empty, or when the two
    cleaned strings are equal ignoring case.
    """
    cleaned_ref = clean_text(ref, stricter=True)
    if cleaned_ref:
        cleaned_hyp = clean_text(hyp, stricter=True)
        if cleaned_ref.lower() != cleaned_hyp.lower():
            scores = get_similarity_batched([cleaned_ref], [cleaned_hyp])
            return float(scores[0])
    return 1.0
|
50 |
+
|
51 |
+
def get_bleu(ref, hyp):
    """Return the sentence-level BLEU score of ``hyp`` against ``ref``.

    Both strings are cleaned (non-strict) first. Strings that are equal
    ignoring case score ``100`` without calling sacrebleu.
    """
    reference = clean_text(ref)
    hypothesis = clean_text(hyp)
    if reference.lower() != hypothesis.lower():
        return sacrebleu.sentence_bleu(hypothesis, [reference]).score
    return 100
|
58 |
+
|
59 |
+
def get_chrf(ref, hyp):
    """Return the sentence-level chrF score of ``hyp`` against ``ref``.

    Both strings are cleaned (non-strict) first. Strings that are equal
    ignoring case score ``100`` without calling sacrebleu.
    """
    reference = clean_text(ref)
    hypothesis = clean_text(hyp)
    if reference.lower() != hypothesis.lower():
        return sacrebleu.sentence_chrf(hypothesis, [reference]).score
    return 100
|
66 |
+
|
tests/test.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
|
4 |
+
def main():
    """Send two sample sentence pairs to a locally running server and print the scores."""
    payload = {
        "texts1": ["Eu gosto de andar de bicicleta nas manhãs de domingo.", "A entrega está programada para amanhã à tarde."],
        "texts2": ["Aos domingos de manhã, eu adoro pedalar.", "A remessa vai chegar amanhã no período da tarde."],
    }
    # requests never times out by default. The limit is generous because the
    # server only loads its embedding model when the first request arrives.
    response = requests.post(
        "http://localhost:8000/api/similarity",
        json=payload,
        timeout=300,
    )
    # Surface HTTP errors instead of trying to decode an error page as JSON.
    response.raise_for_status()
    print(response.json())


if __name__ == "__main__":
    main()
|