# RAG_PEDIATRICS / app.py
# Stéphanie Kamgnia Wonkap
# fixing app.py
# 2b8d974
# Databricks notebook source
import streamlit as st
import os
import yaml
from dotenv import load_dotenv
import torch
from src.generator import answer_with_rag
from ragatouille import RAGPretrainedModel
from src.data_preparation import split_documents
from src.embeddings import init_embedding_model
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from transformers import pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from src.retriever import init_vectorDB_from_doc, retriever
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langchain_community.vectorstores import FAISS
import faiss
def load_config(path="./config.yml"):
    """Load the application settings from a YAML file.

    Args:
        path: Location of the YAML configuration file. Defaults to
            ``./config.yml``, resolved against the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML. The error is
            logged and then re-raised unchanged.
    """
    # Imported here because the module itself never imports ``logging``; the
    # original error path referenced an undefined ``logger`` and therefore
    # failed with NameError instead of reporting the YAML problem.
    import logging

    logger = logging.getLogger(__name__)

    with open(path, "r") as file_object:
        try:
            return yaml.safe_load(file_object)
        except yaml.YAMLError as exc:
            logger.error(str(exc))
            raise
# Parse the configuration once, at import time; the settings below are all
# looked up in the resulting object.
cfg = load_config()

# Disabled attempts at wiring API keys / tokens, left here for reference.
#os.environ['NVIDIA_API_KEY']=st.secrets("NVIDIA_API_KEY")
#load_dotenv("./src/.env")
#HF_TOKEN=os.environ.get["HF_TOKEN"]
#st.write(os.environ["HF_TOKEN"] == st.secrets["HF_TOKEN"])

# Module-level settings. Each lookup is a plain subscript, so an absent key
# fails right here rather than later inside the app.
EMBEDDING_MODEL_NAME = cfg["EMBEDDING_MODEL_NAME"]
DATA_FILE_PATH = cfg["DATA_FILE_PATH"]
READER_MODEL_NAME = cfg["READER_MODEL_NAME"]
RERANKER_MODEL_NAME = cfg["RERANKER_MODEL_NAME"]
VECTORDB_PATH = cfg["VECTORDB_PATH"]
def main():
    """Render the Streamlit page: build the index once, then answer questions.

    When ``st.session_state`` has no ``KNOWLEDGE_VECTOR_DATABASE`` entry, the
    PDF at ``DATA_FILE_PATH`` is loaded, passed through ``split_documents``
    and handed to ``init_vectorDB_from_doc`` together with an NVIDIA
    embedding model; the result (and every intermediate object) is stored in
    the session state.

    Once a question has been typed and the "Get Answer" button is pressed,
    ``answer_with_rag`` is called with a similarity retriever and a
    ``ChatNVIDIA`` model, and the answer plus the returned documents are
    written to the page.
    """
    state = st.session_state

    st.title("Un RAG pour interroger le Collège de Pédiatrie 2024")
    user_query = st.text_input("Entrez votre question:")

    if "KNOWLEDGE_VECTOR_DATABASE" not in state:
        # One-time setup, guarded by the vector-store key. The intermediates
        # are kept in the session state as well.
        state.loader = PyPDFLoader(DATA_FILE_PATH)
        # loader = PyPDFDirectoryLoader(DATA_FILE_PATH)
        state.raw_document_base = state.loader.load()
        state.MARKDOWN_SEPARATORS = [
            "\n#{1,6} ",
            "```\n",
            "\n\\*\\*\\*+\n",
            "\n---+\n",
            "\n___+\n",
            "\n\n",
            "\n",
            " ",
            "",
        ]
        state.docs_processed = split_documents(
            400,  # We choose a chunk size adapted to our model
            state.raw_document_base,
            # tokenizer_name=EMBEDDING_MODEL_NAME,
            separator=state.MARKDOWN_SEPARATORS,
        )
        state.embedding_model = NVIDIAEmbeddings(model="NV-Embed-QA", truncate="END")
        state.KNOWLEDGE_VECTOR_DATABASE = init_vectorDB_from_doc(
            state.docs_processed,
            state.embedding_model,
        )

    # Short-circuit on purpose: the button is only evaluated (and so only
    # drawn) once a non-empty question is present.
    if not user_query or not st.button("Get Answer"):
        return

    top_k = 5  # number of documents fetched by the similarity search
    state.retriever = state.KNOWLEDGE_VECTOR_DATABASE.as_retriever(
        search_type="similarity",
        search_kwargs={"k": top_k},
    )

    st.write("### Please wait while we are getting the answer.....")

    llm = ChatNVIDIA(
        model=READER_MODEL_NAME,
        api_key=os.getenv("NVIDIA_API_KEY"),
        temperature=0.2,
        top_p=0.7,
        max_tokens=1024,
    )
    answer, relevant_docs = answer_with_rag(
        query=user_query,
        llm=llm,
        retriever=state.retriever,
    )

    st.write("### Answer:")
    st.write(answer)

    # Display the relevant documents
    st.write("### Relevant Documents:")
    for i, doc in enumerate(relevant_docs):
        st.write(f"Document {i}:\n{doc}")
# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()