Shreyas094 committed
Commit 79f12fc · verified · 1 Parent(s): fe1ab08

Update app.py

Files changed (1):
  1. app.py +142 -367
app.py CHANGED
@@ -2,114 +2,21 @@ import os
import json
import re
import gradio as gr
-import pandas as pd
import requests
-import random
-import urllib.parse
-import spacy
-from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
-from typing import List, Dict
+from duckduckgo_search import DDGS
+from typing import List
+from pydantic import BaseModel, Field
from tempfile import NamedTemporaryFile
-from bs4 import BeautifulSoup
-from langchain.prompts import PromptTemplate
-from langchain.chains import LLMChain
-from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
-from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFaceHub
-from langchain_core.documents import Document
-from sentence_transformers import SentenceTransformer
from llama_parse import LlamaParse
+from langchain_core.documents import Document

+# Environment variables and configurations
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")

-# Load SentenceTransformer model
-sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
-
-def load_spacy_model():
-    try:
-        # Try to load the model
-        return spacy.load("en_core_web_sm")
-    except OSError:
-        # If loading fails, download the model
-        os.system("python -m spacy download en_core_web_sm")
-        # Try loading again
-        return spacy.load("en_core_web_sm")
-
-# Load spaCy model
-nlp = load_spacy_model()
-
-class EnhancedContextDrivenChatbot:
-    def __init__(self, history_size=10):
-        self.history = []
-        self.history_size = history_size
-        self.entity_tracker = {}
-
-    def add_to_history(self, text):
-        self.history.append(text)
-        if len(self.history) > self.history_size:
-            self.history.pop(0)
-
-        # Update entity tracker
-        doc = nlp(text)
-        for ent in doc.ents:
-            if ent.label_ not in self.entity_tracker:
-                self.entity_tracker[ent.label_] = set()
-            self.entity_tracker[ent.label_].add(ent.text)
-
-    def get_context(self):
-        return " ".join(self.history)
-
-    def is_follow_up_question(self, question):
-        doc = nlp(question.lower())
-        follow_up_indicators = set(['it', 'this', 'that', 'these', 'those', 'he', 'she', 'they', 'them'])
-        return any(token.text in follow_up_indicators for token in doc)
-
-    def extract_topics(self, text):
-        doc = nlp(text)
-        return [chunk.text for chunk in doc.noun_chunks]
-
-    def get_most_relevant_context(self, question):
-        if not self.history:
-            return question
-
-        # Create a combined context from history
-        combined_context = self.get_context()
-
-        # Get embeddings
-        context_embedding = sentence_model.encode([combined_context])[0]
-        question_embedding = sentence_model.encode([question])[0]
-
-        # Calculate similarity
-        similarity = cosine_similarity([context_embedding], [question_embedding])[0][0]
-
-        # If similarity is low, it might be a new topic
-        if similarity < 0.3:  # This threshold can be adjusted
-            return question
-
-        # Otherwise, prepend the context
-        return f"{combined_context} {question}"
-
-    def process_question(self, question):
-        contextualized_question = self.get_most_relevant_context(question)
-
-        # Extract topics from the question
-        topics = self.extract_topics(question)
-
-        # Check if it's a follow-up question
-        if self.is_follow_up_question(question):
-            # If it's a follow-up, make sure to include previous context
-            contextualized_question = f"{self.get_context()} {question}"
-
-        # Add the new question to history
-        self.add_to_history(question)
-
-        return contextualized_question, topics, self.entity_tracker
-
# Initialize LlamaParse
llama_parser = LlamaParse(
    api_key=llama_cloud_api_key,
@@ -136,6 +43,9 @@ def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
    else:
        raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")

+def get_embeddings():
+    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+
def update_vectors(files, parser):
    if not files:
        return "Please upload at least one PDF file."
@@ -159,308 +69,173 @@ def update_vectors(files, parser):

    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."

-def get_embeddings():
-    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
-def clear_cache():
-    if os.path.exists("faiss_database"):
-        os.remove("faiss_database")
-        return "Cache cleared successfully."
-    else:
-        return "No cache to clear."
-
-def get_model(temperature, top_p, repetition_penalty):
-    return HuggingFaceHub(
-        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
-        model_kwargs={
+def generate_chunked_response(prompt, max_tokens=1000, max_chunks=5, temperature=0.3, repetition_penalty=1.1):
+    API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct"
+    headers = {"Authorization": f"Bearer {huggingface_token}"}
+    payload = {
+        "inputs": prompt,
+        "parameters": {
+            "max_new_tokens": max_tokens,
            "temperature": temperature,
-            "top_p": top_p,
+            "top_p": 0.4,
+            "top_k": 40,
            "repetition_penalty": repetition_penalty,
-            "max_length": 1000
-        },
-        huggingfacehub_api_token=huggingface_token
-    )
-
-def generate_chunked_response(model, prompt, max_tokens=1000, max_chunks=5):
+            "stop": ["</s>", "[/INST]"]
+        }
+    }
+
    full_response = ""
-    for i in range(max_chunks):
-        try:
-            chunk = model(prompt + full_response, max_new_tokens=max_tokens)
-            chunk = chunk.strip()
-            if chunk.endswith((".", "!", "?")):
-                full_response += chunk
+    for _ in range(max_chunks):
+        response = requests.post(API_URL, headers=headers, json=payload)
+        if response.status_code == 200:
+            result = response.json()
+            if isinstance(result, list) and len(result) > 0:
+                chunk = result[0].get('generated_text', '')
+
+                # Remove any part of the chunk that's already in full_response
+                new_content = chunk[len(full_response):].strip()
+
+                if not new_content:
+                    break  # No new content, so we're done
+
+                full_response += new_content
+
+                if chunk.endswith((".", "!", "?", "</s>", "[/INST]")):
+                    break
+
+                # Update the prompt for the next iteration
+                payload["inputs"] = full_response
+            else:
                break
-            full_response += chunk
-        except Exception as e:
-            print(f"Error in generate_chunked_response: {e}")
+        else:
            break
-    return full_response.strip()
-
-def extract_text_from_webpage(html):
-    soup = BeautifulSoup(html, 'html.parser')
-    for script in soup(["script", "style"]):
-        script.extract()
-    text = soup.get_text()
-    lines = (line.strip() for line in text.splitlines())
-    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
-    text = '\n'.join(chunk for chunk in chunks if chunk)
-    return text
-
-_useragent_list = [
-    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
-    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
-    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
-    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
-    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
-    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
-]
-
-def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
-    escaped_term = urllib.parse.quote_plus(term)
-    start = 0
-    all_results = []
-    max_chars_per_page = 8000
-
-    print(f"Starting Google search for term: '{term}'")
-
-    with requests.Session() as session:
-        while start < num_results:
-            try:
-                user_agent = random.choice(_useragent_list)
-                headers = {
-                    'User-Agent': user_agent
-                }
-                resp = session.get(
-                    url="https://www.google.com/search",
-                    headers=headers,
-                    params={
-                        "q": term,
-                        "num": num_results - start,
-                        "hl": lang,
-                        "start": start,
-                        "safe": safe,
-                    },
-                    timeout=timeout,
-                    verify=ssl_verify,
-                )
-                resp.raise_for_status()
-                print(f"Successfully retrieved search results page (start={start})")
-            except requests.exceptions.RequestException as e:
-                print(f"Error retrieving search results: {e}")
-                break
-
-            soup = BeautifulSoup(resp.text, "html.parser")
-            result_block = soup.find_all("div", attrs={"class": "g"})
-            if not result_block:
-                print("No results found on this page")
-                break
-
-            print(f"Found {len(result_block)} results on this page")
-            for result in result_block:
-                link = result.find("a", href=True)
-                if link:
-                    link = link["href"]
-                    print(f"Processing link: {link}")
-                    try:
-                        webpage = session.get(link, headers=headers, timeout=timeout)
-                        webpage.raise_for_status()
-                        visible_text = extract_text_from_webpage(webpage.text)
-                        if len(visible_text) > max_chars_per_page:
-                            visible_text = visible_text[:max_chars_per_page] + "..."
-                        all_results.append({"link": link, "text": visible_text})
-                        print(f"Successfully extracted text from {link}")
-                    except requests.exceptions.RequestException as e:
-                        print(f"Error retrieving webpage content: {e}")
-                        all_results.append({"link": link, "text": None})
-                else:
-                    print("No link found for this result")
-                    all_results.append({"link": None, "text": None})
-            start += len(result_block)
-
-    print(f"Search completed. Total results: {len(all_results)}")

-    if not all_results:
-        print("No search results found. Returning a default message.")
-        return [{"link": None, "text": "No information found in the web search results."}]
+    # Clean up the response
+    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
+    clean_response = clean_response.replace("Using the following context:", "").strip()
+    clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
+
+    return clean_response

-    return all_results
+def duckduckgo_search(query):
+    with DDGS() as ddgs:
+        results = ddgs.text(query, max_results=5)
+    return results

-def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
-    if not question:
-        return "Please enter a question."
+class CitingSources(BaseModel):
+    sources: List[str] = Field(
+        ...,
+        description="List of sources to cite. Should be an URL of the source."
+    )

-    model = get_model(temperature, top_p, repetition_penalty)
+def get_response_from_pdf(query, temperature=0.7, repetition_penalty=1.1):
    embed = get_embeddings()
-
    if os.path.exists("faiss_database"):
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    else:
-        database = None
-
-    max_attempts = 3
-    context_reduction_factor = 0.7
-
-    if web_search:
-        contextualized_question, topics, entity_tracker = chatbot.process_question(question)
-        serializable_entity_tracker = {k: list(v) for k, v in entity_tracker.items()}
-        search_results = google_search(contextualized_question)
-        all_answers = []
-
-        for attempt in range(max_attempts):
-            try:
-                web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
-
-                if database is None:
-                    database = FAISS.from_documents(web_docs, embed)
-                else:
-                    database.add_documents(web_docs)
-
-                database.save_local("faiss_database")
-
-                context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
-
-                prompt_template = """
-                Answer the question based on the following web search results, conversation context, and entity information:
-                Web Search Results:
-                {context}
-                Conversation Context: {conv_context}
-                Current Question: {question}
-                Topics: {topics}
-                Entity Information: {entities}
-                If the web search results don't contain relevant information, state that the information is not available in the search results.
-                Provide a summarized and direct answer to the question without mentioning the web search or these instructions.
-                Do not include any source information in your answer.
-                """
-
-                prompt_val = ChatPromptTemplate.from_template(prompt_template)
-                formatted_prompt = prompt_val.format(
-                    context=context_str,
-                    conv_context=chatbot.get_context(),
-                    question=question,
-                    topics=", ".join(topics),
-                    entities=json.dumps(serializable_entity_tracker)
-                )
-
-                full_response = generate_chunked_response(model, formatted_prompt)
-                answer = extract_answer(full_response)
-                all_answers.append(answer)
-                break
-
-            except Exception as e:
-                print(f"Error in ask_question (attempt {attempt + 1}): {e}")
-                if attempt == max_attempts - 1:
-                    all_answers.append(f"I apologize, but I'm having trouble processing the query due to its length or complexity.")
-
-        answer = "\n\n".join(all_answers)
-        sources = set(doc.metadata['source'] for doc in web_docs)
-        sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
-        answer += sources_section
-
-        return answer
-
-    else:  # PDF document chat
-        for attempt in range(max_attempts):
-            try:
-                if database is None:
-                    return "No documents available. Please upload PDF documents to answer questions."
-
-                retriever = database.as_retriever()
-                relevant_docs = retriever.get_relevant_documents(question)
-                context_str = "\n".join([doc.page_content for doc in relevant_docs])
-
-                if attempt > 0:
-                    words = context_str.split()
-                    context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
+        return "No documents available. Please upload PDF documents to answer questions."

-                prompt_template = """
-                Answer the question based on the following context from the PDF document:
-                Context:
-                {context}
-                Question: {question}
-                If the context doesn't contain relevant information, state that the information is not available in the document.
-                Provide a summarized and direct answer to the question.
-                """
+    retriever = database.as_retriever()
+    relevant_docs = retriever.get_relevant_documents(query)
+    context_str = "\n".join([doc.page_content for doc in relevant_docs])

-                prompt_val = ChatPromptTemplate.from_template(prompt_template)
-                formatted_prompt = prompt_val.format(context=context_str, question=question)
+    prompt = f"""<s>[INST] Using the following context from the PDF documents:
+    {context_str}
+    Write a detailed and complete response that answers the following user question: '{query}'
+    Do not include a list of sources in your response. [/INST]"""

-                full_response = generate_chunked_response(model, formatted_prompt)
-                answer = extract_answer(full_response)
+    generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)

-                return answer
+    # Clean the response
+    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
+    clean_text = clean_text.replace("Using the following context from the PDF documents:", "").strip()

-            except Exception as e:
-                print(f"Error in ask_question (attempt {attempt + 1}): {e}")
-                if attempt == max_attempts - 1:
-                    return f"I apologize, but I'm having trouble processing your question. Could you please try rephrasing it more concisely?"
+    return clean_text

-    return "An unexpected error occurred. Please try again later."
-
-def extract_answer(full_response):
-    # First, try to split the response at common instruction phrases
-    answer_patterns = [
-        r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
-        r"Provide a concise and direct answer to the question:",
-        r"Answer:",
-        r"Provide a summarized and direct answer to the question.",
-        r"If the context doesn't contain relevant information, state that the information is not available in the document.",
-        r"Provide a summarized and direct answer to the original question without mentioning the web search or these instructions:",
-        r"Do not include any source information in your answer."
-    ]
-
-    for pattern in answer_patterns:
-        match = re.split(pattern, full_response, flags=re.IGNORECASE)
-        if len(match) > 1:
-            full_response = match[-1].strip()
-            break
+def get_response_with_search(query, temperature=0.7, repetition_penalty=1.1):
+    search_results = duckduckgo_search(query)
+    context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
+                        for result in search_results if 'body' in result)

-    # Then, remove any remaining instruction-like phrases
-    cleanup_patterns = [
-        r"without mentioning the web search or these instructions\.",
-        r"Do not include any source information in your answer\.",
-        r"If the context doesn't contain relevant information, state that the information is not available in the document\."
-    ]
-
-    for pattern in cleanup_patterns:
-        full_response = re.sub(pattern, "", full_response, flags=re.IGNORECASE).strip()
+    prompt = f"""<s>[INST] Using the following context:
+    {context}
+    Write a detailed and complete research document that fulfills the following user request: '{query}'
+    After writing the document, please provide a list of sources used in your response. [/INST]"""
+
+    generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)

-    return full_response
+    # Clean the response
+    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
+    clean_text = clean_text.replace("Using the following context:", "").strip()
+
+    # Split the content and sources
+    parts = clean_text.split("Sources:", 1)
+    main_content = parts[0].strip()
+    sources = parts[1].strip() if len(parts) > 1 else ""
+
+    return main_content, sources
+
+def chatbot_interface(message, history, use_web_search, temperature, repetition_penalty):
+    if use_web_search:
+        main_content, sources = get_response_with_search(message, temperature, repetition_penalty)
+        formatted_response = f"{main_content}\n\nSources:\n{sources}"
+    else:
+        response = get_response_from_pdf(message, temperature, repetition_penalty)
+        formatted_response = response
+
+    history.append((message, formatted_response))
+    return history

# Gradio interface
with gr.Blocks() as demo:
-    gr.Markdown("# Enhanced PDF Document Chat and Web Search")
+    gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")

    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
        parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
-        update_button = gr.Button("Upload PDF")
+        update_button = gr.Button("Upload Document")

    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)

-    with gr.Row():
-        with gr.Column(scale=2):
-            chatbot = gr.Chatbot(label="Conversation")
-            question_input = gr.Textbox(label="Ask a question")
-            submit_button = gr.Button("Submit")
-        with gr.Column(scale=1):
-            temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
-            top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
-            repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
-            web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
-
-    enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()
-
-    def chat(question, history, temperature, top_p, repetition_penalty, web_search):
-        answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot)
-        history.append((question, answer))
-        return "", history
-
-    submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox], outputs=[question_input, chatbot])
-
-    clear_button = gr.Button("Clear Cache")
-    clear_output = gr.Textbox(label="Cache Status")
-    clear_button.click(clear_cache, inputs=[], outputs=clear_output)
+    chatbot = gr.Chatbot(label="Conversation")
+    msg = gr.Textbox(label="Ask a question")
+    use_web_search = gr.Checkbox(label="Use Web Search", value=False)
+
+    with gr.Row():
+        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+        repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
+
+    submit = gr.Button("Submit")
+
+    gr.Examples(
+        examples=[
+            ["What are the latest developments in AI?"],
+            ["Tell me about recent updates on GitHub"],
+            ["What are the best hotels in Galapagos, Ecuador?"],
+            ["Summarize recent advancements in Python programming"],
+        ],
+        inputs=msg,
+    )
+
+    submit.click(chatbot_interface,
+                 inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
+                 outputs=[chatbot])
+    msg.submit(chatbot_interface,
+               inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
+               outputs=[chatbot])
+
+    gr.Markdown(
+        """
+        ## How to use
+        1. Upload PDF documents using the file input at the top.
+        2. Select the PDF parser (pypdf or llamaparse) and click "Upload Document" to update the vector store.
+        3. Ask questions in the textbox.
+        4. Toggle "Use Web Search" to switch between PDF chat and web search.
+        5. Adjust Temperature and Repetition Penalty sliders to fine-tune the response generation.
+        6. Click "Submit" or press Enter to get a response.
+        """
+    )

if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
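
For reference, a minimal sketch of how the entry point added in this commit can be exercised outside the Gradio UI. The driver itself is hypothetical and not part of app.py; it assumes HUGGINGFACE_TOKEN is set and, for the PDF path, that a faiss_database index has already been built via update_vectors.

# Hypothetical usage sketch -- not part of this commit.
history = []
# use_web_search=True routes the message through get_response_with_search,
# which feeds DuckDuckGo results into the [INST] prompt and splits the answer
# on "Sources:"; False routes through get_response_from_pdf instead.
history = chatbot_interface(
    "What are the latest developments in AI?",  # message
    history,                                    # list of (user, bot) tuples
    True,                                       # use_web_search
    0.7,                                        # temperature
    1.1,                                        # repetition_penalty
)
print(history[-1][1])  # the formatted answer followed by its "Sources:" section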