File size: 3,808 Bytes
a6e92fe
 
 
 
 
781a06e
a6e92fe
 
 
eb67361
3504448
eb67361
a6e92fe
 
b99886d
a6e92fe
 
 
 
 
9b55c5a
a6e92fe
 
 
 
 
 
 
 
 
 
58e5d73
9b55c5a
456fb95
 
a6e92fe
 
 
 
 
 
9b55c5a
b99886d
 
 
a6e92fe
3e70cf2
58e5d73
a6e92fe
58e5d73
 
 
 
 
 
 
 
 
 
 
 
 
 
7469d7c
58e5d73
 
 
 
7469d7c
58e5d73
 
3504448
7469d7c
12e199b
 
58e5d73
12e199b
3504448
 
2b8d974
3504448
 
 
 
 
b99886d
 
 
 
 
546fe9e
b99886d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Databricks notebook source
import logging
import os

import faiss
import streamlit as st
import torch
import yaml
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from ragatouille import RAGPretrainedModel
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from src.data_preparation import split_documents
from src.embeddings import init_embedding_model
from src.generator import answer_with_rag
from src.retriever import init_vectorDB_from_doc, retriever
def load_config(config_path="./config.yml"):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML file. Defaults to ``./config.yml`` so
            existing zero-argument callers keep working.

    Returns:
        dict: The parsed configuration.

    Raises:
        yaml.YAMLError: If the file is not valid YAML (logged, then re-raised).
        OSError: If the file cannot be opened.
    """
    with open(config_path, "r") as file_object:
        try:
            cfg = yaml.safe_load(file_object)
        except yaml.YAMLError as exc:
            # Bug fix: the original called `logger.error(...)` but no `logger`
            # was ever defined, so a parse error raised NameError instead of
            # being logged.
            logging.getLogger(__name__).error(str(exc))
            raise
        else:
            return cfg

# Load the configuration once at import time; every constant below comes
# from config.yml. NOTE: this raises at import if config.yml is missing.
cfg= load_config()
#os.environ['NVIDIA_API_KEY']=st.secrets("NVIDIA_API_KEY")
#load_dotenv("./src/.env")
#HF_TOKEN=os.environ.get["HF_TOKEN"] 
#st.write(os.environ["HF_TOKEN"] == st.secrets["HF_TOKEN"])
EMBEDDING_MODEL_NAME=cfg['EMBEDDING_MODEL_NAME']  # embedding model id (not referenced in this file — confirm external use)
DATA_FILE_PATH=cfg['DATA_FILE_PATH']  # path of the PDF indexed by main()
READER_MODEL_NAME=cfg['READER_MODEL_NAME']  # ChatNVIDIA model used to answer queries
RERANKER_MODEL_NAME=cfg['RERANKER_MODEL_NAME']  # NOTE(review): unused in this file — confirm external use
VECTORDB_PATH=cfg['VECTORDB_PATH']  # NOTE(review): unused in this file — confirm external use


def main():
    """Streamlit app: answer user questions over a PDF via RAG.

    On the first run the PDF is loaded, split into chunks, embedded and
    indexed into a vector DB cached in ``st.session_state``; later Streamlit
    reruns reuse that cached index.
    """
    st.title("Un RAG pour interroger le Collège de Pédiatrie 2024")
    user_query = st.text_input("Entrez votre question:")

    if "KNOWLEDGE_VECTOR_DATABASE" not in st.session_state:
        # One-time setup: load, split and embed the source document.
        st.session_state.loader = PyPDFLoader(DATA_FILE_PATH)
        st.session_state.raw_document_base = st.session_state.loader.load()
        st.session_state.MARKDOWN_SEPARATORS = [
            "\n#{1,6} ",
            "```\n",
            "\n\\*\\*\\*+\n",
            "\n---+\n",
            "\n___+\n",
            "\n\n",
            "\n",
            " ",
            "",
        ]
        st.session_state.docs_processed = split_documents(
            400,  # chunk size adapted to the embedding model
            st.session_state.raw_document_base,
            separator=st.session_state.MARKDOWN_SEPARATORS,
        )
        st.session_state.embedding_model = NVIDIAEmbeddings(
            model="NV-Embed-QA", truncate="END"
        )
        st.session_state.KNOWLEDGE_VECTOR_DATABASE = init_vectorDB_from_doc(
            st.session_state.docs_processed,
            st.session_state.embedding_model,
        )

    if user_query and st.button("Get Answer"):
        top_k = 5  # documents retrieved before any reranking
        st.session_state.retriever = (
            st.session_state.KNOWLEDGE_VECTOR_DATABASE.as_retriever(
                search_type="similarity",
                search_kwargs={"k": top_k},
            )
        )

        st.write("### Please wait while we are getting the answer.....")
        reader_llm = ChatNVIDIA(
            model=READER_MODEL_NAME,
            api_key=os.getenv("NVIDIA_API_KEY"),
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
        )
        answer, relevant_docs = answer_with_rag(
            query=user_query,
            llm=reader_llm,
            retriever=st.session_state.retriever,
        )
        st.write("### Answer:")
        st.write(answer)
        # Show the retrieved sources alongside the generated answer.
        st.write("### Relevant Documents:")
        for rank, document in enumerate(relevant_docs):
            st.write(f"Document {rank}:\n{document}")


if __name__ == "__main__":  # Script entry point when run outside a notebook.
    main()