|
|
|
|
|
|
|
|
|
|
|
import logging
import os

import gradio as gr
import pandas as pd

from reranker.reranker import CrossEncReranker
from retriever.es_retriever import ESRetriever
from utils.preprocessing import question_to_statement
|
|
|
|
|
# Elasticsearch connection settings, read from the environment at import time.
# All four are required: a missing variable raises KeyError here, before the
# app starts (lookups happen in the order listed).
ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD = (
    os.environ[name]
    for name in ("ES_HOST", "ES_INDEX_NAME", "ES_USERNAME", "ES_PASSWORD")
)

# Model identifier handed to CrossEncReranker for reranking retrieved passages.
RERANKER_MODEL_NAME = (
    "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
)

# Module-level clients, built once at import time and shared by every call to
# retrieve_and_rerank.
es_retriever_client = ESRetriever(ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD)
# NOTE(review): 512 looks like a maximum token/sequence length for the
# cross-encoder - confirm against CrossEncReranker's signature.
cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
|
|
|
def _not_found_response():
    """Return the placeholder outputs shown when no result can be produced.

    The tuple has the same three-element shape as a successful response
    (reference, text, ranking table) so it matches the three output
    components of the Gradio interface.
    """
    empty_table = pd.DataFrame(columns=["Rank", "Reference", "Text"])
    return "-", "(Result Not Found)", empty_table


def retrieve_and_rerank(question: str, example: str):
    """Retrieve candidate passages for a question and rerank them.

    Args:
        question: Free-text question typed by the user; may be empty or None.
        example: Question chosen from the examples dropdown; used only when
            ``question`` is empty or None.

    Returns:
        A 3-tuple ``(reference, text, table)``: the reference string and the
        text of the top-ranked result, plus a DataFrame with columns
        ``Rank``, ``Reference`` and ``Text`` listing every reranked result in
        order. When there is no question, no result, or retrieval/reranking
        fails, returns ``("-", "(Result Not Found)", <empty DataFrame>)``.
    """
    # Fall back to the dropdown selection when nothing was typed.
    if not question:
        question = example
    if not question:
        # Neither input provided: nothing to search for.
        return _not_found_response()

    query = question_to_statement(question)

    try:
        retrieval_results = es_retriever_client.retrieve(query)
        reranker_results = cross_enc_reranker.rerank(query, retrieval_results)
        law_refs = [result.generate_string() for result in reranker_results]
        law_texts = [result.text for result in reranker_results]
    except Exception:
        # Best-effort UI: show a "not found" result instead of a traceback,
        # but keep the failure visible in the logs and let
        # KeyboardInterrupt/SystemExit propagate (unlike a bare except).
        logging.getLogger(__name__).exception(
            "Retrieval/reranking failed for query %r", query
        )
        return _not_found_response()

    # An empty result set is a normal outcome, not an error.
    if not law_refs:
        return _not_found_response()

    table = pd.DataFrame({
        "Rank": range(1, len(law_refs) + 1),
        "Reference": law_refs,
        "Text": law_texts,
    })
    return law_refs[0], law_texts[0], table
|
|
|
# Example questions (Indonesian) offered in the dropdown; retrieve_and_rerank
# uses the selected one when the free-text box is left empty.
EXAMPLE_QUESTIONS = [
    "Apa yang dimaksud dengan pemberi kerja?",
    "Berapa paling lama waktu kerja lembur?",
    "Apa bentuk pendapatan non-upah?",
]

# Single Interface: (typed question, example dropdown) in;
# (top reference, top passage text, full ranking table) out. The order of
# `outputs` must match the 3-tuple returned by retrieve_and_rerank.
# (A previous gr.Blocks layout here was immediately overwritten by this
# Interface and never launched, so it has been removed.)
demo = gr.Interface(
    fn=retrieve_and_rerank,
    inputs=[
        "text",
        gr.Dropdown(EXAMPLE_QUESTIONS),
    ],
    outputs=[
        "label",
        "text",
        "dataframe",
    ],
)

if __name__ == "__main__":
    demo.launch()
|
|