rag-augment / app.py
davidberenstein1957's picture
Create app.py
502be4c verified
raw
history blame
1.36 kB
import gradio as gr
from sentence_transformers import CrossEncoder
import pandas as pd
# Module-level CrossEncoder instance, created once when the module is loaded
# and used by rerank_documents to score [query, chunk] pairs.
# NOTE(review): the checkpoint name looks like a sentence-embedding model
# rather than a checkpoint published for cross-encoder reranking — confirm
# that it yields meaningful pair scores when loaded through CrossEncoder.
reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")
def rerank_documents(query: str, documents: pd.DataFrame) -> pd.DataFrame:
    """Score each unique chunk against the query and sort by that score.

    Args:
        query: The question every chunk is scored against.
        documents: Table with a ``chunk`` column holding the candidate
            texts. The caller's frame is never modified.

    Returns:
        A new frame without duplicate or missing chunks, with an added
        ``rank`` column holding the value ``reranker.predict`` returned for
        each ``[query, chunk]`` pair, sorted so the highest value comes
        first. When no chunk is left to score, an empty frame that still
        carries the ``rank`` column is returned and the model is not called.

    Raises:
        KeyError: If ``documents`` has no ``chunk`` column.
    """
    # Both calls return new frames, so the input is left untouched without
    # an explicit copy. Missing cells (None/NaN) cannot be scored as text,
    # so they are removed before building the model input.
    candidates = documents.drop_duplicates("chunk").dropna(subset=["chunk"])

    # Nothing to score: skip the model call, which has no pairs to work on,
    # and keep the output schema stable for the results table.
    if candidates.empty:
        return candidates.assign(rank=pd.Series(dtype="float64"))

    pairs = [[query, chunk] for chunk in candidates["chunk"]]
    scores = reranker.predict(pairs)

    # assign() builds a fresh frame, avoiding an in-place write on the
    # filtered result (which can trigger pandas' SettingWithCopy warning).
    return candidates.assign(rank=scores).sort_values(by="rank", ascending=False)
# Gradio UI: a query box and an editable table of chunks go in; pressing the
# button sends both to rerank_documents and shows the returned table.
with gr.Blocks() as demo:
    # Heading and project blurb, rendered as Markdown at the top of the page.
    intro_text = (
        "# RAG Hub Datasets\n"
        "Part of [smol blueprint](https://github.com/davidberenstein1957/smol-blueprint) - a smol blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs."
    )
    gr.Markdown(intro_text)

    # Inputs: free-text query plus a user-editable single-column table.
    query_input = gr.Textbox(
        lines=3,
        label="Query",
        placeholder="Enter your question here...",
    )
    documents_input = gr.Dataframe(
        headers=["chunk"],
        label="Documents",
        interactive=True,
        wrap=True,
    )

    submit_btn = gr.Button("Submit")

    # Output: the same chunks together with their "rank" column.
    documents_output = gr.Dataframe(
        headers=["chunk", "rank"],
        label="Documents",
        wrap=True,
    )

    # Wire the button to the reranking function.
    click_inputs = [query_input, documents_input]
    click_outputs = [documents_output]
    submit_btn.click(
        rerank_documents,
        inputs=click_inputs,
        outputs=click_outputs,
    )

# Start the app as soon as the module is executed.
demo.launch()