|
import gradio as gr
|
|
import logging
|
|
import time
|
|
from generator.compute_metrics import get_attributes_text
|
|
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
|
|
from config import AppConfig, ConfigConstants
|
|
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
|
|
from generator.document_utils import get_logs, initialize_logging
|
|
from retriever.load_selected_datasets import load_selected_datasets
|
|
|
|
# Intro text rendered at the top of the page (gr.Markdown dedents it).
_INTRO_MARKDOWN = """
# Real Time RAG Pipeline Q&A
The **Retrieval-Augmented Generation (RAG) Pipeline** combines retrieval-based and generative AI models to provide accurate and context-aware answers to your questions.
It retrieves relevant documents from a dataset (e.g., COVIDQA, TechQA, FinQA) and uses a generative model to synthesize a response.
Metrics are computed to evaluate the quality of the response and the retrieval process.
"""

# Questions offered in the dropdown below the examples.
_ALL_QUESTIONS = [
    "Does the ignition button have multiple modes?",
    "Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover.",
    "Is one party required to deposit its source code into escrow with a third party, which can be released to the counterparty upon the occurrence of certain events (bankruptcy, insolvency, etc.)?",
    "Explain the concept of blockchain.",
    "What is the capital of France?",
    "Do Surface Porosity and Pore Size Influence Mechanical Properties and Cellular Response to PEEK??",
    "How does a vaccine work?",
    "Tell me the step-by-step instruction for front-door installation.",
    "What are the risk factors for heart disease?",
    "What is the % change in total property and equipment from 2018 to 2019?",
]

# Questions shown as clickable examples next to the query box.
_EXAMPLE_QUESTIONS = [
    "When was the first case of COVID-19 identified?",
    "What are the ages of the patients in this study?",
    "Why cant I load and AEL when using IE 11 JRE 8 Application Blocked by Java Security",
    "Explain the concept of blockchain.",
    "What is the capital of France?",
    "What was the change in Current deferred income?",
]

# Seconds between refreshes of the live-log textbox.
_LOG_REFRESH_SECONDS = 2


def launch_gradio(config: AppConfig):
    """
    Launch the Gradio app with pre-initialized objects.

    Every event handler below is a closure over ``config``. The handlers read
    ``config.vector_store``, ``config.gen_llm``, ``config.val_llm`` and
    ``config.loaded_datasets`` at call time. Picking a different model in the
    UI rebinds ``config.gen_llm`` / ``config.val_llm``.

    Args:
        config: Application configuration holding the pre-initialized LLMs
            and (once a dataset is loaded) the vector store.
    """
    initialize_logging()
    config.detect_loaded_datasets()

    def update_logs_periodically():
        """Yield the current log text forever, refreshing every few seconds."""
        while True:
            # Yield before sleeping so the log box is filled as soon as the
            # page loads instead of staying empty for the first interval.
            yield get_logs()
            time.sleep(_LOG_REFRESH_SECONDS)

    def answer_question(query, state):
        """Retrieve documents and generate an answer for ``query``.

        On success, the query, response and source documents are stored in
        ``state`` so ``compute_metrics`` can evaluate them later.

        Returns:
            ``(response_text, state)`` for the Response textbox and the session
            state. On failure the error is logged and reported in
            ``response_text``; ``state`` is returned without modification.
        """
        try:
            if config.vector_store is None:
                return "Please load a dataset first.", state

            # Avoid a pointless retrieval/LLM round-trip for a blank query.
            if not query or not query.strip():
                return "Please enter a question.", state

            response, source_docs = retrieve_and_generate_response(
                config.gen_llm, config.vector_store, query
            )

            state["query"] = query
            state["response"] = response
            state["source_docs"] = source_docs

            # Same fallback as get_updated_model_info: a model object without a
            # ``name`` must not turn a successful answer into an error.
            model_name = getattr(config.gen_llm, "name", "Unknown")
            response_text = f"Response from Model ({model_name}) : {response}\n\n"
            return response_text, state

        except Exception as e:
            logging.exception("Error processing query: %s", e)
            return f"An error occurred: {e}", state

    def compute_metrics(state):
        """Evaluate the last answer stored in ``state`` with the validation LLM.

        Returns:
            ``(attributes_text, metrics_text)`` for the Attributes and Metrics
            textboxes. On failure the error message goes in the first slot and
            the second is empty.
        """
        try:
            response = state.get("response", "")
            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            # Nothing has been answered yet in this session, so there is
            # nothing to evaluate.
            if not response:
                return "Please submit a question first.", ""

            logging.info("Computing metrics")

            # NOTE(review): the meaning of the trailing ``1`` is defined by
            # generate_metrics — confirm before changing it.
            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)

            attributes_text = get_attributes_text(attributes)

            # One "key: value" line per metric; the 'response' entry is skipped.
            metrics_text = "".join(
                f"{key}: {value}\n" for key, value in metrics.items() if key != "response"
            )

            return attributes_text, metrics_text

        except Exception as e:
            logging.exception("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""

    def reinitialize_llm(model_type, model_name):
        """Reinitialize the specified LLM (generation or validation) and return updated model info.

        A missing or blank ``model_name`` (e.g. a cleared dropdown, which
        yields ``None``) leaves both models untouched. An unrecognized
        ``model_type`` is ignored.
        """
        if model_name and model_name.strip():
            if model_type == "generation":
                config.gen_llm = initialize_generation_llm(model_name)
            elif model_type == "validation":
                config.val_llm = initialize_validation_llm(model_name)

        return get_updated_model_info()

    def get_updated_model_info():
        """Generate and return the updated model information string."""
        loaded_datasets_str = ", ".join(config.loaded_datasets) if config.loaded_datasets else "None"
        gen_llm_name = getattr(config.gen_llm, "name", "Unknown")
        val_llm_name = getattr(config.val_llm, "name", "Unknown")

        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {gen_llm_name}\n"
            f"Re-ranking LLM: {ConfigConstants.RE_RANKER_MODEL_NAME}\n"
            f"Validation LLM: {val_llm_name}\n"
            f"Loaded Datasets: {loaded_datasets_str}\n"
        )

    def load_datasets(datasets):
        """Load the selected datasets into ``config`` and return the refreshed model info.

        Returns only the info string: the click handler has a single output
        component, so returning a tuple would be rendered verbatim in it.
        """
        load_selected_datasets(datasets, config)
        return get_updated_model_info()

    def reinitialize_gen_llm(gen_llm_name):
        """Dropdown handler: swap the generation model."""
        return reinitialize_llm("generation", gen_llm_name)

    def reinitialize_val_llm(val_llm_name):
        """Dropdown handler: swap the validation model."""
        return reinitialize_llm("validation", val_llm_name)

    def update_query_input(selected_question):
        """Copy the question picked in the dropdown into the query box."""
        return selected_question

    with gr.Blocks(title="Real Time RAG Pipeline Q&A") as interface:
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Accordion("System Information", open=False):
            with gr.Accordion("DataSet", open=False):
                with gr.Row():
                    dataset_selector = gr.CheckboxGroup(ConfigConstants.DATA_SET_NAMES, label="Select Datasets to Load")
                    load_button = gr.Button("Load", scale=0)

            with gr.Row():
                with gr.Column(scale=1):
                    new_gen_llm_input = gr.Dropdown(
                        label="Generation Model",
                        choices=ConfigConstants.GENERATION_MODELS,
                        value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,
                        interactive=True,
                        info="Select the generative model for response generation.",
                    )

                with gr.Column(scale=1):
                    new_val_llm_input = gr.Dropdown(
                        label="Validation Model",
                        choices=ConfigConstants.VALIDATION_MODELS,
                        value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,
                        interactive=True,
                        info="Select the model for validating the response quality.",
                    )

                with gr.Column(scale=2):
                    model_info_display = gr.Textbox(
                        value=get_updated_model_info(),
                        label="Model Configuration",
                        interactive=False,
                        lines=5,
                    )

        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    query_input = gr.Textbox(
                        label="Ask a question ",
                        placeholder="Type your query here or select from examples/dropdown",
                        lines=2,
                    )
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary", scale=0)
                    clear_query_button = gr.Button("Clear", scale=0)

            with gr.Column():
                gr.Examples(
                    examples=_EXAMPLE_QUESTIONS,
                    inputs=query_input,
                    label="Try these examples:",
                )
                question_dropdown = gr.Dropdown(
                    label="",
                    choices=_ALL_QUESTIONS,
                    interactive=True,
                    info="Choose a question from the dropdown to populate the query box.",
                )

        question_dropdown.change(
            fn=update_query_input,
            inputs=question_dropdown,
            outputs=query_input,
        )

        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=2)

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary", scale=0)
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # Per-session record of the last answered question, read by compute_metrics.
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})

        load_button.click(load_datasets, inputs=dataset_selector, outputs=model_info_display)

        new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
        new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)

        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state],
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output],
        )

        with gr.Accordion("View Live Logs", open=False):
            with gr.Row():
                # Refreshed by the update_logs_periodically generator on page load;
                # `every=` is omitted because it only applies to callable values.
                log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10)

        interface.load(update_logs_periodically, outputs=log_section)
        # Refresh on every page load so a new session sees the current models/datasets.
        interface.load(get_updated_model_info, outputs=model_info_display)

    interface.queue()
    interface.launch()