douglasfaisal committed on
Commit
72cacc7
·
1 Parent(s): 88b4edc

Add sentence-transformers to requirements

Browse files
Files changed (3) hide show
  1. main.py +0 -38
  2. requirements.txt +1 -0
  3. reranker/reranker.py +1 -1
main.py DELETED
@@ -1,38 +0,0 @@
1
- # This is a sample Python script.
2
-
3
- # Press Shift+F10 to execute it or replace it with your code.
4
- # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5
-
6
- import gradio as gr
7
- import os
8
-
9
- from reranker.reranker import CrossEncReranker
10
- from retriever.es_retriever import ESRetriever
11
- from utils.preprocessing import question_to_statement
12
-
13
-
14
- ES_HOST = os.environ["ES_HOST"]
15
- ES_INDEX_NAME = os.environ["ES_INDEX_NAME"]
16
- ES_USERNAME = os.environ["ES_USERNAME"]
17
- ES_PASSWORD = os.environ["ES_PASSWORD"]
18
-
19
- RERANKER_MODEL_NAME = "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
20
-
21
- es_retriever_client = ESRetriever(ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD)
22
- cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
23
-
24
- def retrieve_and_rerank(question: str):
25
-
26
- query = question_to_statement(question)
27
- retrieval_results = es_retriever_client.retrieve(query)
28
- reranker_results = cross_enc_reranker.rerank(query, retrieval_results)
29
-
30
- return reranker_results[0].text
31
-
32
-
33
- demo = gr.Interface(fn=retrieve_and_rerank, inputs="text", outputs="text")
34
-
35
- # Press the green button in the gutter to run the script.
36
- demo.launch()
37
-
38
- # See PyCharm help at https://www.jetbrains.com/help/pycharm/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio~=3.28.1
2
  numpy==1.21.4
3
  requests==2.26.0
 
 
1
  gradio~=3.28.1
2
  numpy==1.21.4
3
  requests==2.26.0
4
+ sentence-transformers
reranker/reranker.py CHANGED
@@ -10,7 +10,7 @@ class CrossEncReranker:
10
  self.reranker = CrossEncoder(self.model_name)
11
  self.reranker.max_length = max_length
12
 
13
- def rerank(self, query_text: str, candidates: list[LawComponent]):
14
  sentence_combinations = [[query_text, c.text] for c in candidates]
15
  similarity_scores = self.reranker.predict(sentence_combinations)
16
  index = np.argsort(similarity_scores)[::-1]
 
10
  self.reranker = CrossEncoder(self.model_name)
11
  self.reranker.max_length = max_length
12
 
13
+ def rerank(self, query_text: str, candidates: list):
14
  sentence_combinations = [[query_text, c.text] for c in candidates]
15
  similarity_scores = self.reranker.predict(sentence_combinations)
16
  index = np.argsort(similarity_scores)[::-1]