Poonawala committed on
Commit 86d1671 · verified · 1 Parent(s): 6abe096

Update app.py

Files changed (1)
  1. app.py +141 -149
app.py CHANGED
@@ -1,11 +1,73 @@
-import streamlit as st
-import PyPDF2
-from huggingface_hub import InferenceClient
-
-# Initialize the Inference Client
-client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
-
-
 def respond(
     message,
     history: list[tuple[str, str]],
@@ -13,149 +75,79 @@ def respond(
     max_tokens,
     temperature,
     top_p,
-    uploaded_pdf=None
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    # Add previous conversation history to the messages
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    # If a new message is entered, add it to the conversation history
-    messages.append({"role": "user", "content": message})
-
-    # If a PDF is uploaded, process its content
-    if uploaded_pdf is not None:
-        file_content = extract_pdf_text(uploaded_pdf)
-        if file_content:
-            messages.append({"role": "user", "content": f"Document Content: {file_content}"})
-
-    # Get response from the model
-    response = ""
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-        response += token
-        yield response
-
-
-def extract_pdf_text(file):
-    """Extract text from a PDF file."""
-    try:
-        reader = PyPDF2.PdfReader(file)
-        text = ""
-        for page in reader.pages:
-            text += page.extract_text()
-        return text
-    except Exception as e:
-        return f"Error extracting text from PDF: {str(e)}"
-
-
-# Streamlit UI
-st.set_page_config(page_title="Health Assistant", layout="wide")
-
-# Custom CSS for Streamlit app
-st.markdown(
-    """
-    <style>
-    body {
-        background-color: #1e2a38; /* Dark blue background */
-        color: #ffffff; /* White text for readability */
-        font-family: 'Arial', sans-serif; /* Clean and modern font */
-    }
-    .stButton button {
-        background-color: #42B3CE !important; /* Light blue button */
-        color: #2e3b4e !important; /* Dark text for contrast */
-        border: none !important;
-        padding: 10px 20px !important;
-        border-radius: 8px !important;
-        font-size: 16px;
-        font-weight: bold;
-        transition: background-color 0.3s ease, transform 0.2s ease;
-    }
-    .stButton button:hover {
-        background-color: #3189A2 !important; /* Darker blue on hover */
-        transform: scale(1.05);
-    }
-    .stTextInput input {
-        background-color: #2f3b4d;
-        color: white;
-        border: 2px solid #42B3CE;
-        padding: 12px;
-        border-radius: 8px;
-        font-size: 16px;
-        transition: border 0.3s ease;
-    }
-    .stTextInput input:focus {
-        border-color: #3189A2;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True,
 )
 
-# Title and description
-st.title("Health Assistant Chat")
-st.subheader("Chat with your health assistant and upload a document for analysis")
-
-# System message for health-related responses
-system_message = (
-    "You are a virtual health assistant designed to provide accurate and reliable information "
-    "related to health, wellness, and medical topics. Your primary goal is to assist users with "
-    "their health-related queries, offer general guidance, and suggest when to consult a licensed "
-    "medical professional. If a user asks a question that is unrelated to health, wellness, or medical "
-    "topics, respond politely but firmly with: 'I'm sorry, I can't help with that because I am a virtual "
-    "health assistant designed to assist with health-related needs. Please let me know if you have any health-related questions.'"
 )
 
-# Upload a PDF file
-uploaded_pdf = st.file_uploader("Upload a PDF file (Optional)", type="pdf")
-
-# User input message
-message = st.text_input("Type your health-related question:")
-
-# History for conversation tracking
-if 'history' not in st.session_state:
-    st.session_state['history'] = []
-
-# Collect and display previous conversation history
-history = st.session_state['history']
-for user_message, assistant_message in history:
-    st.markdown(f"**You:** {user_message}")
-    st.markdown(f"**Assistant:** {assistant_message}")
-
-# Max tokens, temperature, and top-p sliders
-max_tokens = st.slider("Max new tokens", min_value=1, max_value=2048, value=512)
-temperature = st.slider("Temperature", min_value=0.1, max_value=4.0, value=0.7, step=0.1)
-top_p = st.slider("Top-p (nucleus sampling)", min_value=0.1, max_value=1.0, value=0.95, step=0.05)
-
-# Button to generate response
-if st.button("Generate Response"):
-    if message:
-        # Append the user's question to the conversation history
-        st.session_state.history.append((message, ""))
-        # Generate the response based on the user's input and any uploaded document
-        response = respond(
-            message,
-            st.session_state.history,
-            system_message,
-            max_tokens,
-            temperature,
-            top_p,
-            uploaded_pdf
-        )
-        # Display the response
-        for resp in response:
-            st.markdown(f"**Assistant:** {resp}")
-            # Update the conversation history with the assistant's response
-            st.session_state.history[-1] = (message, resp)
-    else:
-        st.error("Please enter a question to proceed.")
 
+import gradio as gr
+from langchain_community.vectorstores import FAISS
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.chains import ConversationalRetrievalChain
+from langchain_community.llms import HuggingFaceEndpoint
+from langchain.chains import ConversationChain
+from langchain.memory import ConversationBufferMemory
+import os
+
+api_token = os.getenv("HF_TOKEN")
+
+# List of LLMs
+list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
+list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+# Load and split PDF documents
+def load_doc(list_file_path):
+    loaders = [PyPDFLoader(x) for x in list_file_path]
+    pages = []
+    for loader in loaders:
+        pages.extend(loader.load())
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+    doc_splits = text_splitter.split_documents(pages)
+    return doc_splits
+
+# Create vector database
+def create_db(splits):
+    embeddings = HuggingFaceEmbeddings()
+    vectordb = FAISS.from_documents(splits, embeddings)
+    return vectordb
+
+# Initialize LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
+    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+    else:
+        llm = HuggingFaceEndpoint(
+            huggingfacehub_api_token=api_token,
+            repo_id=llm_model,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+
+    retriever = vector_db.as_retriever()
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        chain_type="stuff",
+        memory=memory,
+        return_source_documents=True,
+        verbose=False,
+    )
+    return qa_chain
+
+# Function to handle chatbot responses
 def respond(
     message,
     history: list[tuple[str, str]],
     max_tokens,
     temperature,
     top_p,
+    vector_db,
+    llm_model,
 ):
+    # Initialize LLM chain if not already initialized
+    if not hasattr(respond, 'qa_chain'):
+        respond.qa_chain = initialize_llmchain(llm_model, temperature, max_tokens, top_p, vector_db)
+
+    # Format chat history
+    formatted_chat_history = []
+    for user_message, bot_message in history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    formatted_chat_history.append(f"User: {message}")
+
+    # Generate response using QA chain
+    response = respond.qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
+
+    return response_answer
+
+# CSS for styling the interface
+css = """
+body {
+    background-color: #06688E; /* Dark background */
+    color: white; /* Text color for better visibility */
+}
+.gr-button {
+    background-color: #42B3CE !important; /* White button color */
+    color: black !important; /* Black text for contrast */
+    border: none !important;
+    padding: 8px 16px !important;
+    border-radius: 5px !important;
+}
+.gr-button:hover {
+    background-color: #e0e0e0 !important; /* Slightly lighter button on hover */
+}
+.gr-slider-container {
+    color: white !important; /* Slider labels in white */
+}
+"""
+
+# Initialize database and LLM chain
+def initialize_database_and_llm(list_file_obj, llm_option, max_tokens, temperature, top_p):
+    list_file_path = [x.name for x in list_file_obj if x is not None]
+    doc_splits = load_doc(list_file_path)
+    vector_db = create_db(doc_splits)
+    llm_name = list_llm[llm_option]
+    return vector_db, llm_name
+
+# Gradio interface
+demo = gr.ChatInterface(
+    respond,
+    additional_inputs=[
+        gr.Files(file_count="multiple", file_types=["pdf"], label="Upload PDF documents", visible=False),
+        gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple, visible=False),
+        gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max new tokens", visible=False),
+        gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", visible=False),
+        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k", visible=False),
+    ],
+    css=css,
+    title="RAG PDF Chatbot",
+    description="Query your PDF documents using a Retrieval Augmented Generation (RAG) chatbot.",
 )
 
+# Preprocessing events
+demo.preprocess(
+    initialize_database_and_llm,
+    inputs=["document", "llm_btn", "slider_maxtokens", "slider_temperature", "slider_topk"],
+    outputs=["vector_db", "llm_model"],
+    api_name="initialize",
 )
 
+if __name__ == "__main__":
+    demo.launch(share=True)
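
For reference, a minimal sketch of how the RAG pipeline introduced in this commit could be exercised outside the Gradio UI, assuming HF_TOKEN is set in the environment; the file name "sample.pdf" and the example question are illustrative, not part of the commit:

# Minimal sketch of the commit's RAG pipeline, run outside Gradio.
# Assumes HF_TOKEN is exported and a local "sample.pdf" exists (illustrative name).
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Load and chunk the PDF, then build the FAISS index (mirrors load_doc/create_db).
pages = PyPDFLoader("sample.pdf").load()
splits = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64).split_documents(pages)
vector_db = FAISS.from_documents(splits, HuggingFaceEmbeddings())

# Build the conversational retrieval chain (mirrors initialize_llmchain).
llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.5,
    max_new_tokens=4096,
    top_k=3,
)
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=vector_db.as_retriever(),
    chain_type="stuff",
    memory=memory,
    return_source_documents=True,
)

# Ask one question; the attached memory carries chat history across further calls.
result = qa_chain.invoke({"question": "What is this document about?"})
print(result["answer"])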