ThisIs-Developer commited on
Commit
0cffca7
1 Parent(s): ea4c98d

Upload model.py

Browse files
Files changed (1) hide show
  1. Streamlit/model.py +132 -0
Streamlit/model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain import PromptTemplate
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+
9
+ DB_FAISS_PATH = 'vectorstores/db_faiss'
10
+
11
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
12
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
13
+
14
+ Context: {context}
15
+ Question: {question}
16
+
17
+ Only return the helpful answer below and nothing else.
18
+ Helpful answer:
19
+ """
20
+
21
+ def set_custom_prompt():
22
+ prompt = PromptTemplate(template=custom_prompt_template,
23
+ input_variables=['context', 'question'])
24
+ return prompt
25
+
26
+ def retrieval_qa_chain(llm, prompt, db):
27
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
28
+ chain_type='stuff',
29
+ retriever=db.as_retriever(search_kwargs={'k': 2}),
30
+ return_source_documents=True,
31
+ chain_type_kwargs={'prompt': prompt}
32
+ )
33
+ return qa_chain
34
+
35
+ def load_llm():
36
+ llm = CTransformers(
37
+ model="TheBloke/Llama-2-7B-Chat-GGML",
38
+ model_type="llama",
39
+ max_new_tokens=512,
40
+ temperature=0.5
41
+ )
42
+ return llm
43
+
44
+ def qa_bot(query):
45
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
46
+ model_kwargs={'device': 'cpu'})
47
+ db = FAISS.load_local(DB_FAISS_PATH, embeddings)
48
+ llm = load_llm()
49
+ qa_prompt = set_custom_prompt()
50
+ qa = retrieval_qa_chain(llm, qa_prompt, db)
51
+
52
+ # Implement the question-answering logic here
53
+ response = qa({'query': query})
54
+ return response['result']
55
+
56
+ def add_vertical_space(spaces=1):
57
+ for _ in range(spaces):
58
+ st.markdown("---")
59
+
60
+ def main():
61
+ st.set_page_config(page_title="Llama-2-GGML Medical Chatbot")
62
+
63
+ with st.sidebar:
64
+ st.title('Llama-2-GGML Medical Chatbot! 馃殌馃')
65
+ st.markdown('''
66
+ ## About
67
+
68
+ The Llama-2-GGML Medical Chatbot uses the **Llama-2-7B-Chat-GGML** model and was trained on medical data from **"The GALE ENCYCLOPEDIA of MEDICINE"**.
69
+
70
+ ### 馃攧Bot evolving, stay tuned!
71
+ ## Useful Links 馃敆
72
+
73
+ - **Model:** [Llama-2-7B-Chat-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML) 馃摎
74
+ - **GitHub:** [ThisIs-Developer/Llama-2-GGML-Medical-Chatbot](https://github.com/ThisIs-Developer/Llama-2-GGML-Medical-Chatbot) 馃挰
75
+ ''')
76
+ add_vertical_space(1) # Adjust the number of spaces as needed
77
+ st.write('Made by [@ThisIs-Developer](https://huggingface.co/ThisIs-Developer)')
78
+
79
+ st.title("Llama-2-GGML Medical Chatbot")
80
+ st.markdown(
81
+ """
82
+ <style>
83
+ .chat-container {
84
+ display: flex;
85
+ flex-direction: column;
86
+ height: 400px;
87
+ overflow-y: auto;
88
+ padding: 10px;
89
+ color: white; /* Font color */
90
+ }
91
+ .user-bubble {
92
+ background-color: #007bff; /* Blue color for user */
93
+ align-self: flex-end;
94
+ border-radius: 10px;
95
+ padding: 8px;
96
+ margin: 5px;
97
+ max-width: 70%;
98
+ word-wrap: break-word;
99
+ }
100
+ .bot-bubble {
101
+ background-color: #363636; /* Slightly lighter background color */
102
+ align-self: flex-start;
103
+ border-radius: 10px;
104
+ padding: 8px;
105
+ margin: 5px;
106
+ max-width: 70%;
107
+ word-wrap: break-word;
108
+ }
109
+ </style>
110
+ """
111
+ , unsafe_allow_html=True)
112
+
113
+ conversation = st.session_state.get("conversation", [])
114
+
115
+ query = st.text_input("Ask your question here:", key="user_input")
116
+ if st.button("Get Answer"):
117
+ if query:
118
+ with st.spinner("Processing your question..."): # Display the processing message
119
+ conversation.append({"role": "user", "message": query})
120
+ # Call your QA function
121
+ answer = qa_bot(query)
122
+ conversation.append({"role": "bot", "message": answer})
123
+ st.session_state.conversation = conversation
124
+ else:
125
+ st.warning("Please input a question.")
126
+
127
+ chat_container = st.empty()
128
+ chat_bubbles = ''.join([f'<div class="{c["role"]}-bubble">{c["message"]}</div>' for c in conversation])
129
+ chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
130
+
131
+ if __name__ == "__main__":
132
+ main()