GovindRaj committed on
Commit 1aa1394 · 1 Parent(s): 18b55b9

added changes

Files changed (3)
  1. app.py +83 -0
  2. requirements.txt +11 -0
  3. vectorstore/db_faiss/model.py +105 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import streamlit as st
+ from langchain.prompts import PromptTemplate
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.llms import CTransformers
+ from langchain.chains import RetrievalQA
+
+ DB_FAISS_PATH = 'vectorstore/db_faiss'
+
+ custom_prompt_template = """Use the following pieces of information to answer the user's question. If you don't know the answer, just say that you don't know; don't try to make up an answer.
+
+ Context: {context}
+ Question: {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer: """
+
+ def set_custom_prompt():
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=['context', 'question'])
+     return prompt
+
+ def retrieval_qa_chain(llm, prompt, db):
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type='stuff',
+         retriever=db.as_retriever(search_kwargs={'k': 2}),
+         return_source_documents=True,
+         chain_type_kwargs={'prompt': prompt}
+     )
+     return qa_chain
+
+ def load_llm():
+     model_path = "/home/ebiz/Govind/Llama2-Medical-Chatbot/llama-2-7b-chat.ggmlv3.q4_0.bin"
+     llm = CTransformers(
+         model=model_path,
+         model_type="llama",
+         max_new_tokens=1024,
+         temperature=0.5
+     )
+     return llm
+
+ # Cache the chain so the model and FAISS index load once, not on every Streamlit rerun
+ @st.cache_resource
+ def qa_bot():
+     embeddings = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/all-MiniLM-L6-v2",
+         model_kwargs={'device': 'cpu'}
+     )
+     db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
+     llm = load_llm()
+     qa_prompt = set_custom_prompt()
+     qa = retrieval_qa_chain(llm, qa_prompt, db)
+     return qa
+
+ def main():
+     st.title("Medical Chatbot")
+
+     # Initialize session state for chat history
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     # Display chat history
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # Chat input
+     if prompt := st.chat_input("What is your medical query?"):
+         # Display user message
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         # Generate and display assistant response
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 qa_chain = qa_bot()
+                 response = qa_chain({'query': prompt})
+                 st.markdown(response["result"])
+                 st.session_state.messages.append({"role": "assistant", "content": response["result"]})
+
+ if __name__ == '__main__':
+     main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ pypdf
+ langchain
+ torch
+ accelerate
+ bitsandbytes
+ ctransformers
+ sentence_transformers
+ faiss_cpu
+ chainlit
+ huggingface_hub
+ langchain_community
+ streamlit
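
With these dependencies installed (pip install -r requirements.txt), the two entry points run separately: streamlit run app.py for the Streamlit UI, and chainlit run vectorstore/db_faiss/model.py for the Chainlit UI.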
vectorstore/db_faiss/model.py ADDED
@@ -0,0 +1,98 @@
+ from langchain.prompts import PromptTemplate
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.llms import CTransformers
+ from langchain.chains import RetrievalQA
+ import chainlit as cl
+
+ DB_FAISS_PATH = 'vectorstore/db_faiss'
+
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
+ If you don't know the answer, just say that you don't know; don't try to make up an answer.
+
+ Context: {context}
+ Question: {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer:
+ """
+
+ def set_custom_prompt():
+     """
+     Prompt template for QA retrieval for each vectorstore
+     """
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=['context', 'question'])
+     return prompt
+
+ # Retrieval QA chain
+ def retrieval_qa_chain(llm, prompt, db):
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type='stuff',
+         retriever=db.as_retriever(search_kwargs={'k': 2}),
+         return_source_documents=True,
+         chain_type_kwargs={'prompt': prompt}
+     )
+     return qa_chain
+
+ # Load the locally downloaded GGML model
+ def load_llm():
+     # Path to the specific GGML model file to use
+     model_path = "/home/ebiz/Govind/Llama2-Medical-Chatbot/llama-2-7b-chat.ggmlv3.q4_0.bin"
+     llm = CTransformers(
+         model=model_path,
+         model_type="llama",
+         max_new_tokens=1024,
+         temperature=0.5
+     )
+     return llm
+
+ # QA model function
+ def qa_bot():
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                        model_kwargs={'device': 'cpu'})
+     db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
+     llm = load_llm()
+     qa_prompt = set_custom_prompt()
+     qa = retrieval_qa_chain(llm, qa_prompt, db)
+     return qa
+
+ # Output function
+ def final_result(query):
+     qa_result = qa_bot()
+     response = qa_result({'query': query})
+     return response
+
+ # Chainlit code
+ @cl.on_chat_start
+ async def start():
+     chain = qa_bot()
+     msg = cl.Message(content="Starting the bot...")
+     await msg.send()
+     msg.content = "Hi, welcome to Medical Bot. What is your query?"
+     await msg.update()
+
+     cl.user_session.set("chain", chain)
+
+ @cl.on_message
+ async def main(message: cl.Message):
+     chain = cl.user_session.get("chain")
+
+     # Disable streaming of the final answer to prevent duplicate responses
+     cb = cl.AsyncLangchainCallbackHandler(
+         stream_final_answer=False
+     )
+
+     res = await chain.acall(message.content, callbacks=[cb])
+     answer = res["result"]
+     sources = res.get("source_documents", [])
+
+     # Optionally append source info to the answer:
+     # if sources:
+     #     source_info = "\nSources:\n" + "\n".join([doc.metadata.get("source", "Unknown") for doc in sources])
+     #     answer += source_info
+     # else:
+     #     answer += "\nNo sources found"
+
+     await cl.Message(content=answer).send()
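
Both entry points load a prebuilt FAISS index from vectorstore/db_faiss, but no file in this commit creates it. Below is a minimal ingestion sketch, assuming the source PDFs sit in a hypothetical data/ directory and reusing the same MiniLM embedder the apps query with:

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

DATA_PATH = 'data/'                     # hypothetical location of the source PDFs
DB_FAISS_PATH = 'vectorstore/db_faiss'  # same path both apps load from

def create_vector_db():
    # Load every PDF under DATA_PATH
    loader = DirectoryLoader(DATA_PATH, glob='*.pdf', loader_cls=PyPDFLoader)
    documents = loader.load()

    # Split into overlapping chunks sized for the MiniLM embedder
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = splitter.split_documents(documents)

    # Embed with the same model the apps use at query time
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                       model_kwargs={'device': 'cpu'})

    # Build the FAISS index and persist it where the apps expect it
    db = FAISS.from_documents(texts, embeddings)
    db.save_local(DB_FAISS_PATH)

if __name__ == '__main__':
    create_vector_db()

Run it once before launching either UI so that FAISS.load_local finds the index.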