Gourisankar Padihary
Added Gradio UI
e234b58
raw
history blame
1.74 kB
import logging
from data.load_dataset import load_data
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents
from generator.generate_metrics import generate_metrics
from generator.initialize_llm import initialize_generation_llm
from generator.initialize_llm import initialize_validation_llm
from app import launch_gradio
# Configure the root logger once, at import time: emit INFO and above, with each
# record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main():
    """Run the RAG pipeline end to end and hand the results to the Gradio app.

    Steps, in order:
      1. Load the dataset named by ``dataset_name`` (currently 'covidqa').
      2. Split it into chunks (3000 for 'cuad', 1000 for anything else).
      3. Embed the chunks into a vector store.
      4. Initialize the generation and validation LLMs.
      5. Pass the vector store, dataset and both LLMs to ``launch_gradio``.

    Progress is reported through ``logging.info``. The function takes no
    arguments and returns nothing.
    """
    logging.info("Starting the RAG pipeline")

    # Load the dataset.
    dataset_name = 'covidqa'
    corpus = load_data(dataset_name)
    logging.info("Dataset loaded")

    # Pick the chunk size: datasets listed here override the default of 1000.
    chunk_size_overrides = {'cuad': 3000}
    chunk_size = chunk_size_overrides.get(dataset_name, 1000)

    # Chunk the dataset.
    chunks = chunk_documents(corpus, chunk_size)
    logging.info("Documents chunked")

    # Embed the chunks.
    store = embed_documents(chunks)
    logging.info("Documents embedded")

    # One model produces answers, the other is used for validation.
    generation_llm = initialize_generation_llm()
    validation_llm = initialize_validation_llm()

    # Disabled alternatives to the UI, kept for reference.
    # Metrics for a single sample question:
    #   row_num = 30
    #   query = corpus[row_num]['question']
    #   generate_metrics(generation_llm, validation_llm, store, query)
    # RMSE and AUC-ROC over the entire dataset:
    #   compute_rmse_auc_roc_metrics(generation_llm, validation_llm, corpus, store, 10)

    # Launch the Gradio app.
    # NOTE(review): whether this call blocks until the UI is closed depends on
    # launch_gradio, which is defined in app.py - confirm there.
    launch_gradio(store, corpus, generation_llm, validation_llm)

    logging.info("Finished!!!")
# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()