# Databricks notebook source
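# A quick usage note (assuming this file is saved as app.py and the
# NVIDIA_API_KEY environment variable is set):
#
#   streamlit run app.py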
"""Streamlit RAG app that answers questions over the Collège de Pédiatrie
2024 PDF, using NVIDIA AI endpoints for embeddings and generation and a
ColBERT-style reranker."""

import logging
import os

import streamlit as st
import yaml
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_nvidia_ai_endpoints.embeddings import NVIDIAEmbeddings
from ragatouille import RAGPretrainedModel

from src.data_preparation import split_documents
from src.generator import answer_with_rag
from src.retriever import init_vectorDB_from_doc

logger = logging.getLogger(__name__)


def load_config():
    """Read runtime settings from ./config.yml."""
    with open("./config.yml", "r") as file_object:
        try:
            cfg = yaml.safe_load(file_object)
        except yaml.YAMLError as exc:
            logger.error(str(exc))
            raise
        else:
            return cfg


cfg = load_config()

# Settings read from config.yml.
EMBEDDING_MODEL_NAME = cfg["EMBEDDING_MODEL_NAME"]
DATA_FILE_PATH = cfg["DATA_FILE_PATH"]
READER_MODEL_NAME = cfg["READER_MODEL_NAME"]
RERANKER_MODEL_NAME = cfg["RERANKER_MODEL_NAME"]
VECTORDB_PATH = cfg["VECTORDB_PATH"]
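
# A minimal sketch of the expected ./config.yml. The keys are the ones read
# above; the values below are illustrative placeholders, not the project's
# actual settings:
#
#   EMBEDDING_MODEL_NAME: "thenlper/gte-small"
#   DATA_FILE_PATH: "./data/college_pediatrie_2024.pdf"
#   READER_MODEL_NAME: "meta/llama3-8b-instruct"
#   RERANKER_MODEL_NAME: "colbert-ir/colbertv2.0"
#   VECTORDB_PATH: "./vectordb"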


def main():
    st.title("Un RAG pour interroger le Collège de Pédiatrie 2024")
    user_query = st.text_input("Entrez votre question:")

    if "KNOWLEDGE_VECTOR_DATABASE" not in st.session_state:
        # Build the vector store once per session, or reload it from disk
        # if it was persisted by a previous run.
        st.session_state.embedding_model = NVIDIAEmbeddings()
        if os.path.exists(VECTORDB_PATH):
            st.session_state.KNOWLEDGE_VECTOR_DATABASE = FAISS.load_local(
                VECTORDB_PATH,
                st.session_state.embedding_model,
                allow_dangerous_deserialization=True,
            )
        else:
            loader = PyPDFLoader(DATA_FILE_PATH)
            raw_document_base = loader.load()
            # Separators tried in order, from coarse (headings, rules)
            # down to fine (paragraphs, words, characters).
            MARKDOWN_SEPARATORS = [
                "\n#{1,6} ",
                "```\n",
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                "\n\n",
                "\n",
                " ",
                "",
            ]
            docs_processed = split_documents(
                512,  # chunk size adapted to the embedding model
                raw_document_base,
                separator=MARKDOWN_SEPARATORS,
            )
            st.session_state.KNOWLEDGE_VECTOR_DATABASE = init_vectorDB_from_doc(
                docs_processed, st.session_state.embedding_model
            )
            st.session_state.KNOWLEDGE_VECTOR_DATABASE.save_local(VECTORDB_PATH)

    if st.button("Get Answer"):
        # Reader LLM served through the NVIDIA API endpoint.
        llm = ChatNVIDIA(
            model=READER_MODEL_NAME,
            api_key=os.environ.get("NVIDIA_API_KEY"),
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
        )
        # ColBERT-style reranker narrows the retrieved chunks down to the
        # few most relevant ones before generation.
        reranker = RAGPretrainedModel.from_pretrained(RERANKER_MODEL_NAME)
        num_doc_before_rerank = 15
        num_final_relevant_docs = 5
        answer, relevant_docs = answer_with_rag(
            query=user_query,
            READER_MODEL_NAME=READER_MODEL_NAME,
            embedding_model=st.session_state.embedding_model,
            vectorDB=st.session_state.KNOWLEDGE_VECTOR_DATABASE,
            reranker=reranker,
            llm=llm,
            num_doc_before_rerank=num_doc_before_rerank,
            num_final_relevant_docs=num_final_relevant_docs,
            rerank=True,
        )

        # Display the answer
        st.write("### Answer:")
        st.write(answer)

        # Display the relevant documents
        st.write("### Relevant Documents:")
        for i, doc in enumerate(relevant_docs):
            st.write(f"Document {i}:\n{doc}")


if __name__ == "__main__":
    main()