|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import os |
|
|
|
from reranker.reranker import CrossEncReranker |
|
from retriever.es_retriever import ESRetriever |
|
from utils.preprocessing import question_to_statement |
|
|
|
|
|
# Elasticsearch connection settings are read from the environment. A missing
# variable raises KeyError at import time, so the app fails fast on an
# incomplete deployment configuration.
ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD = (
    os.environ[variable]
    for variable in ("ES_HOST", "ES_INDEX_NAME", "ES_USERNAME", "ES_PASSWORD")
)

# Identifier of the cross-encoder model handed to the reranker.
RERANKER_MODEL_NAME = (
    "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
)

# Module-level singletons shared by every request handled by the app.
es_retriever_client = ESRetriever(
    ES_HOST,
    ES_INDEX_NAME,
    ES_USERNAME,
    ES_PASSWORD,
)
# NOTE(review): 512 is presumably the maximum input sequence length in
# tokens -- confirm against CrossEncReranker's signature.
cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
|
|
|
def retrieve_and_rerank(question: str) -> str:
    """Answer *question* with the text of the top reranked passage.

    The question is rewritten with ``question_to_statement``, candidates are
    fetched through ``es_retriever_client.retrieve`` and reordered by
    ``cross_enc_reranker.rerank``; the ``text`` attribute of the first
    reranked item is returned.

    Args:
        question: Free-text question typed by the user.

    Returns:
        The text of the first reranked result, or a short explanatory
        message when the question is blank or nothing was found.
    """
    empty_question_message = "Please enter a question."
    no_result_message = "No relevant passage was found for this question."

    # Guard against blank input before touching the retriever: an empty
    # query cannot produce a meaningful search.
    if not question or not question.strip():
        return empty_question_message

    query = question_to_statement(question)
    retrieval_results = es_retriever_client.retrieve(query)

    # Skip the (potentially expensive) model call when there is nothing to
    # rerank. NOTE(review): assumes retrieve() returns a sequence or None --
    # confirm against ESRetriever.
    if not retrieval_results:
        return no_result_message

    reranker_results = cross_enc_reranker.rerank(query, retrieval_results)

    # The original indexed [0] unconditionally and raised IndexError on an
    # empty result; report the miss to the user instead.
    try:
        top_result = reranker_results[0]
    except IndexError:
        return no_result_message

    return top_result.text
|
|
|
|
|
# Minimal Gradio UI: one text input (the question) mapped through
# retrieve_and_rerank to one text output (the returned passage).
demo = gr.Interface(
    fn=retrieve_and_rerank,
    inputs="text",
    outputs="text",
)

# Start the web server as soon as the script is executed.
demo.launch()
|
|
|
|
|
|