"""Streamlit RAG app: answer questions over the Collège de Pédiatrie PDF."""

import logging
import os

import streamlit as st
import torch
import yaml
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from ragatouille import RAGPretrainedModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

from src.data_preparation import split_documents
from src.generator import answer_with_rag
# init_embedding_model is assumed to be defined in src.retriever alongside
# init_vectorDB_from_doc; adjust this import if it lives elsewhere.
from src.retriever import init_embedding_model, init_vectorDB_from_doc, retriever

logger = logging.getLogger(__name__)
|
def load_config():
    """Load application settings from the YAML config file."""
    with open("./src/config.yml", "r") as file_object:
        try:
            cfg = yaml.safe_load(file_object)
        except yaml.YAMLError as exc:
            logger.error(str(exc))
            raise
        else:
            return cfg

|
cfg = load_config()
load_dotenv("./src/.env")

EMBEDDING_MODEL_NAME = cfg['EMBEDDING_MODEL_NAME']
DATA_FILE_PATH = cfg['DATA_FILE_PATH']
READER_MODEL_NAME = cfg['READER_MODEL_NAME']
RERANKER_MODEL_NAME = cfg['RERANKER_MODEL_NAME']
VECTORDB_PATH = cfg['VECTORDB_PATH']
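
# NOTE: Streamlit re-runs this script from top to bottom on every user
# interaction, so the models and index below are reloaded on each rerun;
# wrapping the heavy loaders in @st.cache_resource would avoid that.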
|
if __name__ == "__main__":
    st.title("RAG App to query the Collège de Pédiatrie")

    user_query = st.text_input("Enter your question:")

    # Load the source PDF as one document per page.
    loader = PyPDFLoader(DATA_FILE_PATH)
    raw_document_base = loader.load()
|
    MARKDOWN_SEPARATORS = [
        "\n#{1,6} ",
        "```\n",
        "\n\\*\\*\\*+\n",
        "\n---+\n",
        "\n___+\n",
        "\n\n",
        "\n",
        " ",
        "",
    ]
|
    docs_processed = split_documents(
        512,
        raw_document_base,
        tokenizer_name=EMBEDDING_MODEL_NAME,
        separator=MARKDOWN_SEPARATORS,
    )
    embedding_model = init_embedding_model(EMBEDDING_MODEL_NAME)
|
    # Reuse the persisted FAISS index when it exists; otherwise build it from
    # the processed documents and save it for later runs.
    if os.path.exists(VECTORDB_PATH):
        # allow_dangerous_deserialization is needed because FAISS indexes are
        # pickled; only load files you created yourself.
        KNOWLEDGE_VECTOR_DATABASE = FAISS.load_local(
            VECTORDB_PATH,
            embedding_model,
            allow_dangerous_deserialization=True,
        )
    else:
        KNOWLEDGE_VECTOR_DATABASE = init_vectorDB_from_doc(docs_processed, embedding_model)
        KNOWLEDGE_VECTOR_DATABASE.save_local(VECTORDB_PATH)
|
    if st.button("Get Answer"):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
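        # With this config the reader's weights load in 4-bit NF4, cutting
        # GPU memory to roughly a quarter of fp16 at some quality cost.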
|
        model = AutoModelForCausalLM.from_pretrained(
            READER_MODEL_NAME, quantization_config=bnb_config
        )
        tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
|
        READER_LLM = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            return_full_text=False,
            max_new_tokens=500,
        )
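        # Sampling with a low temperature keeps answers close to the
        # retrieved context; return_full_text=False strips the prompt from
        # the pipeline output.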
|
        # ColBERT-style reranker (via RAGatouille) re-scores the retrieved
        # passages before the top ones go to the reader.
        RERANKER = RAGPretrainedModel.from_pretrained(RERANKER_MODEL_NAME)
        num_doc_before_rerank = 15
        num_final_relevant_docs = 5

        answer, relevant_docs = answer_with_rag(
            query=user_query,
            READER_MODEL_NAME=READER_MODEL_NAME,
            embedding_model=embedding_model,
            vectorDB=KNOWLEDGE_VECTOR_DATABASE,
            reranker=RERANKER,
            llm=READER_LLM,
            num_doc_before_rerank=num_doc_before_rerank,
            num_final_relevant_docs=num_final_relevant_docs,
            rerank=True,
        )
|
st.write("### Answer:") |
|
st.write(answer) |
|
|
|
|
|
st.write("### Relevant Documents:") |
|
for i, doc in enumerate(relevant_docs): |
|
st.write(f"Document {i}:\n{doc.text}") |