Somnath3570 committed
Commit 4f260fc · verified · 1 Parent(s): a6def53

Upload 3 files

Files changed (3)
  1. connect_memory_with_llm.py +63 -0
  2. create_memory_for_llm.py +46 -0
  3. medibot.py +102 -0
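
All three scripts expect an HF_TOKEN environment variable (typically supplied through a .env file at the project root), and the two query scripts additionally expect a FAISS index under vectorstore/db_faiss produced by create_memory_for_llm.py. A minimal preflight sketch, assuming that layout:

import os
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv())
assert os.environ.get("HF_TOKEN"), "Set HF_TOKEN in .env or your shell first"
assert os.path.isdir("data"), "Place source PDFs in data/ before building the index"
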
connect_memory_with_llm.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ 
+ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_core.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_community.vectorstores import FAISS
+ 
+ # Load environment variables from .env (pipenv users get this automatically)
+ from dotenv import load_dotenv, find_dotenv
+ load_dotenv(find_dotenv())
+ 
+ 
+ # Step 1: Set up the LLM (Mistral via Hugging Face)
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
+ 
+ def load_llm(huggingface_repo_id):
+     # Token and generation length are explicit parameters here; recent
+     # langchain_huggingface releases may reject them inside model_kwargs.
+     llm = HuggingFaceEndpoint(
+         repo_id=huggingface_repo_id,
+         task="text-generation",
+         temperature=0.5,
+         max_new_tokens=512,
+         huggingfacehub_api_token=HF_TOKEN,
+     )
+     return llm
+ 
+ # Step 2: Connect the LLM with FAISS and create the chain
+ 
+ CUSTOM_PROMPT_TEMPLATE = """
+ Use the pieces of information provided in the context to answer the user's question.
+ If you don't know the answer, just say that you don't know; don't try to make up an answer.
+ Don't provide anything outside of the given context.
+ 
+ Context: {context}
+ Question: {question}
+ 
+ Start the answer directly. No small talk please.
+ """
+ 
+ def set_custom_prompt(custom_prompt_template):
+     prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
+     return prompt
+ 
+ # Load the FAISS database built by create_memory_for_llm.py
+ DB_FAISS_PATH = "vectorstore/db_faiss"
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
+ 
+ # Create the QA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=load_llm(HUGGINGFACE_REPO_ID),
+     chain_type="stuff",
+     retriever=db.as_retriever(search_kwargs={'k': 3}),
+     return_source_documents=True,
+     chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
+ )
+ 
+ # Invoke the chain with a single query
+ user_query = input("Write Query Here: ")
+ response = qa_chain.invoke({'query': user_query})
+ print("RESULT: ", response["result"])
+ print("SOURCE DOCUMENTS: ", response["source_documents"])
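
The "stuff" chain does nothing clever with the retrieved chunks: it concatenates the top-k documents into {context} and drops the raw query into {question}. A standalone sketch of that substitution, using the same PromptTemplate API with placeholder text in place of real retrieved chunks:

from langchain_core.prompts import PromptTemplate

template = PromptTemplate(
    template="Context: {context}\nQuestion: {question}",
    input_variables=["context", "question"],
)
# Placeholder strings stand in for retrieved chunks and a user query
print(template.format(context="chunk 1 ...\nchunk 2 ...", question="What causes anemia?"))
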
create_memory_for_llm.py ADDED
@@ -0,0 +1,46 @@
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ 
+ # Load environment variables from .env (pipenv users get this automatically)
+ from dotenv import load_dotenv, find_dotenv
+ load_dotenv(find_dotenv())
+ 
+ 
+ # Step 1: Load raw PDF(s)
+ DATA_PATH = "data/"
+ 
+ def load_pdf_files(data):
+     loader = DirectoryLoader(data,
+                              glob='*.pdf',
+                              loader_cls=PyPDFLoader)
+     documents = loader.load()
+     return documents
+ 
+ documents = load_pdf_files(data=DATA_PATH)
+ #print("Length of PDF pages: ", len(documents))
+ 
+ 
+ # Step 2: Create chunks
+ def create_chunks(extracted_data):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
+                                                    chunk_overlap=50)
+     text_chunks = text_splitter.split_documents(extracted_data)
+     return text_chunks
+ 
+ text_chunks = create_chunks(extracted_data=documents)
+ #print("Length of Text Chunks: ", len(text_chunks))
+ 
+ # Step 3: Create vector embeddings
+ def get_embedding_model():
+     embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     return embedding_model
+ 
+ embedding_model = get_embedding_model()
+ 
+ # Step 4: Store embeddings in FAISS
+ DB_FAISS_PATH = "vectorstore/db_faiss"
+ db = FAISS.from_documents(text_chunks, embedding_model)
+ db.save_local(DB_FAISS_PATH)
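
After the script runs, vectorstore/db_faiss holds the serialized index. A quick sanity check, assuming the index was written with the same embedding model, is to load it back and run a raw similarity search (the query string here is just a placeholder; any phrase from your PDFs works):

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.load_local("vectorstore/db_faiss", embedding_model,
                      allow_dangerous_deserialization=True)
for doc in db.similarity_search("diabetes symptoms", k=3):  # placeholder query
    print(doc.metadata.get("page"), doc.page_content[:100])
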
medibot.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import streamlit as st
+ 
+ # Imports aligned with the other two scripts
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
+ from langchain.chains import RetrievalQA
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.prompts import PromptTemplate
+ 
+ from dotenv import load_dotenv, find_dotenv
+ load_dotenv(find_dotenv())
+ 
+ DB_FAISS_PATH = "vectorstore/db_faiss"
+ 
+ @st.cache_resource
+ def get_vectorstore():
+     embedding_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+     db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
+     return db
+ 
+ def set_custom_prompt(custom_prompt_template):
+     prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
+     return prompt
+ 
+ def load_llm(huggingface_repo_id, HF_TOKEN):
+     # Token and generation length are explicit parameters here; recent
+     # langchain_huggingface releases may reject them inside model_kwargs.
+     llm = HuggingFaceEndpoint(
+         repo_id=huggingface_repo_id,
+         task="text-generation",
+         temperature=0.5,
+         max_new_tokens=512,
+         huggingfacehub_api_token=HF_TOKEN,
+     )
+     return llm
+ 
+ def main():
+     st.title("Ask Chatbot!")
+ 
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+ 
+     # Replay the conversation so far
+     for message in st.session_state.messages:
+         st.chat_message(message['role']).markdown(message['content'])
+ 
+     prompt = st.chat_input("Pass your prompt here")
+ 
+     if prompt:
+         st.chat_message('user').markdown(prompt)
+         st.session_state.messages.append({'role': 'user', 'content': prompt})
+ 
+         CUSTOM_PROMPT_TEMPLATE = """
+         Use the pieces of information provided in the context to answer the user's question.
+         If you don't know the answer, just say that you don't know; don't try to make up an answer.
+         Don't provide anything outside of the given context.
+ 
+         Context: {context}
+         Question: {question}
+ 
+         Start the answer directly. No small talk please.
+         """
+ 
+         HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
+         HF_TOKEN = os.environ.get("HF_TOKEN")
+ 
+         try:
+             with st.spinner("Thinking..."):  # loading indicator while the chain runs
+                 vectorstore = get_vectorstore()
+                 if vectorstore is None:
+                     st.error("Failed to load the vector store")
+                     return
+ 
+                 qa_chain = RetrievalQA.from_chain_type(
+                     llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID, HF_TOKEN=HF_TOKEN),
+                     chain_type="stuff",
+                     retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
+                     return_source_documents=True,
+                     chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
+                 )
+ 
+                 response = qa_chain.invoke({'query': prompt})
+ 
+             result = response["result"]
+             source_documents = response["source_documents"]
+ 
+             # Format source documents more cleanly
+             source_docs_text = "\n\n**Source Documents:**\n"
+             for i, doc in enumerate(source_documents, 1):
+                 source_docs_text += f"{i}. Page {doc.metadata.get('page', 'N/A')}: {doc.page_content[:200]}...\n\n"
+ 
+             result_to_show = f"{result}\n{source_docs_text}"
+ 
+             st.chat_message('assistant').markdown(result_to_show)
+             st.session_state.messages.append({'role': 'assistant', 'content': result_to_show})
+ 
+         except Exception as e:
+             st.error(f"Error: {str(e)}")
+             st.error("Please check your HuggingFace token and model access permissions")
+ 
+ if __name__ == "__main__":
+     main()
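
One design note on medibot.py: the RetrievalQA chain is rebuilt on every user message even though nothing about it changes between turns. A small refactor sketch that caches the chain the same way the vector store is cached; it assumes HUGGINGFACE_REPO_ID, HF_TOKEN, and CUSTOM_PROMPT_TEMPLATE are hoisted to module level:

import streamlit as st
from langchain.chains import RetrievalQA

@st.cache_resource
def get_qa_chain():
    # Built once per session; Streamlit reruns reuse the cached object
    return RetrievalQA.from_chain_type(
        llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID, HF_TOKEN=HF_TOKEN),
        chain_type="stuff",
        retriever=get_vectorstore().as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)},
    )

main() would then call get_qa_chain().invoke({'query': prompt}) inside the spinner instead of constructing the chain per message.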