heyday1234 committed on
Commit
edbf017
·
verified ·
1 Parent(s): 72e4f8f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +111 -110
model.py CHANGED
@@ -1,110 +1,111 @@
1
- """
2
- The code in this script subjects to a licence of 96harsh52/LLaMa_2_chatbot (https://github.com/96harsh52/LLaMa_2_chatbot)
3
- Youtube instruction (https://www.youtube.com/watch?v=kXuHxI5ZcG0&list=PLrLEqwuz-mRIdQrfeCjeCyFZ-Pl6ffPIN&index=18)
4
- Llama 2 Model (Quantized one by the Bloke): https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin
5
- Llama 2 HF Model (Original One): https://huggingface.co/meta-llama
6
- Chainlit docs: https://github.com/Chainlit/chainlit
7
- """
8
-
9
- from langchain import PromptTemplate
10
- from langchain_community.embeddings import HuggingFaceEmbeddings
11
- from langchain_community.vectorstores import FAISS
12
- from langchain.chains import RetrievalQA
13
- from langchain_community.llms import CTransformers
14
- import chainlit as cl
15
-
16
- DB_FAISS_PATH = 'vectorstore/db_faiss'
17
-
18
- custom_prompt_template = """Use the following pieces of information to answer the user's question.
19
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
20
-
21
- Context: {context}
22
- Question: {question}
23
-
24
- Only return the helpful answer below and nothing else.
25
- Helpful answer:
26
- """
27
-
28
-
29
- def set_custom_prompt():
30
- """
31
- Prompt template for QA retrieval for each vectorstore
32
- """
33
- prompt = PromptTemplate(template=custom_prompt_template,
34
- input_variables=['context', 'question'])
35
- return prompt
36
-
37
-
38
- def load_llm():
39
- """
40
- Load the language model
41
- """
42
- llm = CTransformers(model='llama-2-7b-chat.ggmlv3.q8_0.bin',
43
- model_type='llama',
44
- max_new_tokens=512,
45
- temperature=0.5)
46
- return llm
47
-
48
-
49
- def retrieval_qa_chain(llm, prompt, db):
50
- """
51
- Create a retrieval QA chain
52
- """
53
- qa_chain = RetrievalQA.from_chain_type(
54
- llm=llm,
55
- chain_type='stuff',
56
- retriever=db.as_retriever(search_kwargs={'k': 2}),
57
- return_source_documents=True,
58
- chain_type_kwargs={'prompt': prompt}
59
- )
60
- return qa_chain
61
-
62
-
63
- def qa_bot():
64
- """
65
- Create a QA bot
66
- """
67
- embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
68
- model_kwargs={'device': 'cpu'})
69
- db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
70
- llm = load_llm()
71
- qa_prompt = set_custom_prompt()
72
- qa = retrieval_qa_chain(llm, qa_prompt, db)
73
- return qa
74
-
75
-
76
- def final_result(query):
77
- qa_result = qa_bot()
78
- response = qa_result({'query': query})
79
- return response
80
-
81
-
82
- @cl.on_chat_start
83
- async def start():
84
- chain = qa_bot()
85
- msg = cl.Message(content="Starting the bot...")
86
- await msg.send()
87
- msg.content = "Hi, Welcome to Medical Chatbot. What is your query?"
88
- await msg.update()
89
- cl.user_session.set("chain", chain)
90
-
91
-
92
- @cl.on_message
93
- async def main(message: cl.Message):
94
- chain = cl.user_session.get("chain")
95
- cb = cl.AsyncLangchainCallbackHandler(
96
- stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
97
- )
98
- cb.answer_reached = True
99
- res = await chain.acall(message.content, callbacks=[cb])
100
- answer = res["result"]
101
- sources = res["source_documents"]
102
-
103
- if sources:
104
- answer += f"\nSources:" + str(sources)
105
- else:
106
- answer += "\nNo sources found"
107
-
108
- await cl.Message(content=answer).send()
109
-
110
-
 
 
1
+ """
2
+ The code in this script is subject to the licence of 96harsh52/LLaMa_2_chatbot (https://github.com/96harsh52/LLaMa_2_chatbot)
3
+ Youtube instruction (https://www.youtube.com/watch?v=kXuHxI5ZcG0&list=PLrLEqwuz-mRIdQrfeCjeCyFZ-Pl6ffPIN&index=18)
4
+ Llama 2 Model (Quantized one by the Bloke): https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin
5
+ Llama 2 HF Model (Original One): https://huggingface.co/meta-llama
6
+ Chainlit docs: https://github.com/Chainlit/chainlit
7
+ """
8
+
9
+ from langchain import PromptTemplate
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain.chains import RetrievalQA
13
+ from langchain_community.llms import CTransformers
14
+ import chainlit as cl
15
+
# Location of the persisted FAISS index that qa_bot() loads.
DB_FAISS_PATH = "vectorstore/db_faiss"

# Prompt text for the retrieval QA chain. The {context} and {question}
# placeholders are the input variables declared in set_custom_prompt().
custom_prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer.\n"
    "\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
27
+
28
+
def set_custom_prompt():
    """Build the prompt used for QA retrieval.

    Returns:
        A ``PromptTemplate`` wrapping ``custom_prompt_template`` with the
        ``context`` and ``question`` input variables.
    """
    return PromptTemplate(
        input_variables=['context', 'question'],
        template=custom_prompt_template,
    )
36
+
37
+
def load_llm():
    """Load the quantized Llama 2 chat model through CTransformers.

    Returns:
        A ``CTransformers`` LLM for the ``llama-2-7b-chat.Q8_0.gguf`` file of
        the ``TheBloke/Llama-2-7b-Chat-GGUF`` repository.
    """
    # Generation settings must go through ``config``: the LangChain wrapper
    # forwards only that dict to ctransformers, so passing ``max_new_tokens``
    # or ``temperature`` as top-level keyword arguments leaves the library
    # defaults in effect.
    generation_config = {
        'max_new_tokens': 512,
        'temperature': 0.5,
    }
    llm = CTransformers(
        model='TheBloke/Llama-2-7b-Chat-GGUF',
        model_file='llama-2-7b-chat.Q8_0.gguf',
        model_type='llama',
        config=generation_config,
    )
    return llm
48
+
49
+
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a ``RetrievalQA`` chain over the given vector store.

    Args:
        llm: Language model handed to the chain.
        prompt: Prompt forwarded to the chain through ``chain_type_kwargs``.
        db: Vector store; its ``as_retriever`` is called with
            ``search_kwargs={'k': 2}``.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type`` with the
        ``'stuff'`` chain type and ``return_source_documents=True``.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
62
+
63
+
def qa_bot():
    """Create the QA bot: embeddings, FAISS store, LLM, and prompt in one chain.

    Returns:
        The retrieval QA chain built by ``retrieval_qa_chain``.
    """
    embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # NOTE(review): the flag name indicates the loader performs unsafe
    # deserialization; only point DB_FAISS_PATH at an index you trust.
    store = FAISS.load_local(
        DB_FAISS_PATH,
        embedder,
        allow_dangerous_deserialization=True,
    )
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), store)
75
+
76
+
def final_result(query):
    """Answer a single query with a freshly built QA bot.

    Args:
        query: The question, passed to the chain under the ``'query'`` key.

    Returns:
        Whatever the chain returns for that query.
    """
    # A new chain is built on every call (qa_bot() is invoked each time).
    chain = qa_bot()
    return chain({'query': query})
81
+
82
+
@cl.on_chat_start
async def start():
    """Initialise a chat session: show a status message, build and store the chain.

    The status message is sent *before* the chain is built so the user sees
    "Starting the bot..." while the model loads, and the blocking ``qa_bot()``
    call is run via ``cl.make_async`` so it does not stall the event loop.
    """
    msg = cl.Message(content="Starting the bot...")
    await msg.send()

    # qa_bot() is synchronous; make_async runs it off the event loop.
    chain = await cl.make_async(qa_bot)()
    # Store the chain first so it is available as soon as the greeting shows.
    cl.user_session.set("chain", chain)

    msg.content = "Hi, Welcome to Medical Chatbot. What is your query?"
    await msg.update()
91
+
92
+
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message with the session's QA chain.

    Args:
        message: The incoming Chainlit message; its ``content`` is the query.

    Sends a reply containing the chain's ``result`` followed by either the
    string form of the returned source documents or "No sources found". If
    the session has no chain stored, a short notice is sent instead.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # Without this guard the call below fails with AttributeError on None.
        await cl.Message(
            content="The bot is not ready yet. Please start a new chat."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): presumably forces streaming from the first token, since
    # the prompt never emits the "FINAL ANSWER" prefix -- confirm.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]

    if sources:
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()
110
+
111
+