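"""Entry point for the RAG pipeline.

Loads a QA dataset, chunks and embeds its documents into a vector store,
initializes the generation and validation LLMs, and launches the Gradio app.
"""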
import logging
from data.load_dataset import load_data
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
from retriever.chunk_documents import chunk_documents
from retriever.embed_documents import embed_documents
from generator.generate_metrics import generate_metrics
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
from app import launch_gradio

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def main():
    logging.info("Starting the RAG pipeline")
    data_set_name = 'covidqa'

    # Load the dataset
    dataset = load_data(data_set_name)
    logging.info("Dataset loaded")

    # Chunk the dataset (cuad, a long-document contract dataset, uses larger chunks)
    chunk_size = 1000  # default value
    if data_set_name == 'cuad':
        chunk_size = 3000
    documents = chunk_documents(dataset, chunk_size)
    logging.info("Documents chunked")

    # Embed the documents
    vector_store = embed_documents(documents)
    logging.info("Documents embedded")
    
    # Initialize the Generation LLM
    gen_llm = initialize_generation_llm()

    # Initialize the Validation LLM
    val_llm = initialize_validation_llm()
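    # gen_llm generates answers from the retrieved context; val_llm is used by the
    # metric helpers (generate_metrics, compute_rmse_auc_roc_metrics) to score them.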

    # Sample question (uncomment to run a single query through the pipeline)
    # row_num = 30
    # query = dataset[row_num]['question']

    # Call generate_metrics for the sample question above
    # generate_metrics(gen_llm, val_llm, vector_store, query)

    # Compute RMSE and AUC-ROC for the entire dataset
    # compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, 10)
    
    # Launch the Gradio app
    launch_gradio(vector_store, dataset, gen_llm, val_llm)

    logging.info("Finished!!!")

if __name__ == "__main__":
    main()