File size: 1,738 Bytes
f7c2fa3 026aeba e384879 026aeba e384879 9bde774 e234b58 bd69eee 026aeba bd69eee bcc15bd bd69eee 026aeba b58a992 bd69eee 026aeba b58a992 bd69eee 026aeba bd69eee 79dcf63 e384879 9bde774 bd69eee e384879 e234b58 026aeba e384879 bcc15bd f7c2fa3 e234b58 e384879 e234b58 e384879 026aeba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import logging
from data.load_dataset import load_data
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents
from generator.generate_metrics import generate_metrics
from generator.initialize_llm import initialize_generation_llm
from generator.initialize_llm import initialize_validation_llm
from app import launch_gradio
# Set up root logging at import time: INFO level and above, each record
# prefixed with a timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def main(data_set_name='covidqa', chunk_size=None):
    """Run the RAG pipeline end to end and launch the Gradio app.

    Loads the dataset, chunks it, embeds the chunks into a vector store,
    initializes the generation and validation LLMs, then passes everything
    to ``launch_gradio``.

    Args:
        data_set_name: Dataset name passed to ``load_data``. Defaults to
            ``'covidqa'``, the value this script previously hard-coded.
        chunk_size: Chunk size passed to ``chunk_documents``. When ``None``
            (the default), a per-dataset value is used: 3000 for ``'cuad'``,
            1000 for every other dataset.
    """
    # Per-dataset chunk-size overrides; anything not listed uses the default.
    default_chunk_size = 1000
    chunk_size_overrides = {'cuad': 3000}

    logging.info("Starting the RAG pipeline")

    # Load the dataset
    dataset = load_data(data_set_name)
    logging.info("Dataset loaded")

    # Chunk the dataset (an explicit chunk_size argument wins over the table)
    if chunk_size is None:
        chunk_size = chunk_size_overrides.get(data_set_name, default_chunk_size)
    documents = chunk_documents(dataset, chunk_size)
    logging.info("Documents chunked (chunk_size=%s)", chunk_size)

    # Embed the documents
    vector_store = embed_documents(documents)
    logging.info("Documents embedded")

    # Initialize the generation and validation LLMs
    gen_llm = initialize_generation_llm()
    val_llm = initialize_validation_llm()
    logging.info("LLMs initialized")

    # NOTE: optional manual evaluation hooks, kept disabled. Uncomment to
    # score a single sample question, or to compute RMSE / AUC-ROC over the
    # dataset (the meaning of the trailing `10` argument is defined by
    # compute_rmse_auc_roc_metrics — confirm before enabling).
    # Sample question
    #row_num = 30
    #query = dataset[row_num]['question']
    # Call generate_metrics for above sample question
    #generate_metrics(gen_llm, val_llm, vector_store, query)
    #Compute RMSE and AUC-ROC for entire dataset
    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)

    # Launch the Gradio app
    launch_gradio(vector_store, dataset, gen_llm, val_llm)
    # NOTE(review): if launch_gradio blocks until the server stops, this is
    # only logged on shutdown — confirm.
    logging.info("Finished!!!")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()