Commit
·
72cacc7
1
Parent(s):
88b4edc
Add sentence-transformers to requirements
Browse files
- main.py +0 -38
- requirements.txt +1 -0
- reranker/reranker.py +1 -1
main.py
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
# This is a sample Python script.
|
2 |
-
|
3 |
-
# Press Shift+F10 to execute it or replace it with your code.
|
4 |
-
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
5 |
-
|
6 |
-
import gradio as gr
|
7 |
-
import os
|
8 |
-
|
9 |
-
from reranker.reranker import CrossEncReranker
|
10 |
-
from retriever.es_retriever import ESRetriever
|
11 |
-
from utils.preprocessing import question_to_statement
|
12 |
-
|
13 |
-
|
14 |
-
ES_HOST = os.environ["ES_HOST"]
|
15 |
-
ES_INDEX_NAME = os.environ["ES_INDEX_NAME"]
|
16 |
-
ES_USERNAME = os.environ["ES_USERNAME"]
|
17 |
-
ES_PASSWORD = os.environ["ES_PASSWORD"]
|
18 |
-
|
19 |
-
RERANKER_MODEL_NAME = "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
|
20 |
-
|
21 |
-
es_retriever_client = ESRetriever(ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD)
|
22 |
-
cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
|
23 |
-
|
24 |
-
def retrieve_and_rerank(question: str):
|
25 |
-
|
26 |
-
query = question_to_statement(question)
|
27 |
-
retrieval_results = es_retriever_client.retrieve(query)
|
28 |
-
reranker_results = cross_enc_reranker.rerank(query, retrieval_results)
|
29 |
-
|
30 |
-
return reranker_results[0].text
|
31 |
-
|
32 |
-
|
33 |
-
demo = gr.Interface(fn=retrieve_and_rerank, inputs="text", outputs="text")
|
34 |
-
|
35 |
-
# Press the green button in the gutter to run the script.
|
36 |
-
demo.launch()
|
37 |
-
|
38 |
-
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
gradio~=3.28.1
|
2 |
numpy==1.21.4
|
3 |
requests==2.26.0
|
|
|
|
1 |
gradio~=3.28.1
|
2 |
numpy==1.21.4
|
3 |
requests==2.26.0
|
4 |
+
sentence-transformers
|
reranker/reranker.py
CHANGED
@@ -10,7 +10,7 @@ class CrossEncReranker:
|
|
10 |
self.reranker = CrossEncoder(self.model_name)
|
11 |
self.reranker.max_length = max_length
|
12 |
|
13 |
-
def rerank(self, query_text: str, candidates: list
|
14 |
sentence_combinations = [[query_text, c.text] for c in candidates]
|
15 |
similarity_scores = self.reranker.predict(sentence_combinations)
|
16 |
index = np.argsort(similarity_scores)[::-1]
|
|
|
10 |
self.reranker = CrossEncoder(self.model_name)
|
11 |
self.reranker.max_length = max_length
|
12 |
|
13 |
+
def rerank(self, query_text: str, candidates: list):
|
14 |
sentence_combinations = [[query_text, c.text] for c in candidates]
|
15 |
similarity_scores = self.reranker.predict(sentence_combinations)
|
16 |
index = np.argsort(similarity_scores)[::-1]
|