Shreyas094 committed on
Commit a2c0e0e
1 Parent(s): 790409e

Update app.py

Files changed (1)
  1. app.py +80 -180
app.py CHANGED
@@ -1,203 +1,103 @@
- import os
- import json
- import re
  import gradio as gr
  import requests
- import random
- import urllib.parse
- from tempfile import NamedTemporaryFile
- from bs4 import BeautifulSoup
  from typing import List
  from pydantic import BaseModel, Field
- from huggingface_hub import InferenceApi
- from duckduckgo_search import DDGS
- from langchain_community.vectorstores import FAISS
- from langchain_community.document_loaders import PyPDFLoader
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFaceHub
- from langchain_core.documents import Document
- from sentence_transformers import SentenceTransformer
- from llama_parse import LlamaParse
- import logging
-
- # Set up logging
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

  # Environment variables and configurations
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
- llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
-
- # Initialize SentenceTransformer and LlamaParse
- sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
- llama_parser = LlamaParse(
-     api_key=llama_cloud_api_key,
-     result_type="markdown",
-     num_workers=4,
-     verbose=True,
-     language="en",
- )
-
- def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
-     if parser == "pypdf":
-         loader = PyPDFLoader(file.name)
-         return loader.load_and_split()
-     elif parser == "llamaparse":
-         try:
-             documents = llama_parser.load_data(file.name)
-             return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
-         except Exception as e:
-             print(f"Error using Llama Parse: {str(e)}")
-             print("Falling back to PyPDF parser")
-             loader = PyPDFLoader(file.name)
-             return loader.load_and_split()
-     else:
-         raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
-
- def update_vectors(files, parser):
-     if not files:
-         return "Please upload at least one PDF file."
-
-     embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-     total_chunks = 0
-
-     all_data = []
-     for file in files:
-         data = load_document(file, parser)
-         all_data.extend(data)
-         total_chunks += len(data)
-
-     if os.path.exists("faiss_database"):
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         database.add_documents(all_data)
-     else:
-         database = FAISS.from_documents(all_data, embed)
-
-     database.save_local("faiss_database")
-
-     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."
-
- def clear_cache():
-     if os.path.exists("faiss_database"):
-         os.remove("faiss_database")
-         return "Cache cleared successfully."
-     else:
-         return "No cache to clear."
-
- def get_model(temperature, top_p, repetition_penalty):
-     return HuggingFaceHub(
-         repo_id="mistralai/Mistral-7B-Instruct-v0.3",
-         model_kwargs={
-             "temperature": temperature,
-             "top_p": top_p,
-             "repetition_penalty": repetition_penalty,
-             "max_length": 1000
-         },
-         huggingfacehub_api_token=huggingface_token
-     )

  def duckduckgo_search(query):
-     logging.debug(f"Performing DuckDuckGo search for query: {query}")
      with DDGS() as ddgs:
-         results = list(ddgs.text(query, max_results=5))
-         logging.debug(f"Search returned {len(results)} results")
      return results

- def get_response_with_search(query, temperature, top_p, repetition_penalty, use_pdf=False):
-     logging.debug(f"Getting response for query: {query}")
-     logging.debug(f"Parameters: temperature={temperature}, top_p={top_p}, repetition_penalty={repetition_penalty}, use_pdf={use_pdf}")
-
-     model = get_model(temperature, top_p, repetition_penalty)
-     embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
-     if use_pdf and os.path.exists("faiss_database"):
-         logging.debug("Using PDF database for context")
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         retriever = database.as_retriever()
-         relevant_docs = retriever.get_relevant_documents(query)
-         context = "\n".join([f"Content: {doc.page_content}\nSource: {doc.metadata['source']}\n" for doc in relevant_docs])
-     else:
-         logging.debug("Using web search for context")
-         search_results = duckduckgo_search(query)
-         context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
-                             for result in search_results if 'body' in result)
-
-     logging.debug(f"Context generated. Length: {len(context)} characters")
-
      prompt = f"""<s>[INST] Using the following context:
  {context}
  Write a detailed and complete research document that fulfills the following user request: '{query}'
- After the main content, provide a list of sources used in your response, prefixed with 'Sources:'.
- Do not include any part of these instructions in your response. [/INST]"""
-
-     logging.debug("Sending prompt to model")
-     response = model(prompt)
-     logging.debug(f"Received response from model. Length: {len(response)} characters")
-
-     main_content, sources = split_response(response)
-
-     logging.debug(f"Split response. Main content length: {len(main_content)}, Sources length: {len(sources)}")
-     return main_content, sources
-
- def split_response(response):
-     logging.debug("Splitting response")
-     logging.debug(f"Original response: {response[:100]}...")  # Log first 100 characters
-
-     # Remove any remaining instruction text
-     response = re.sub(r'\[/?INST\]', '', response)
-     response = re.sub(r'~~.*?~~', '', response, flags=re.DOTALL)
-
-     logging.debug(f"After removing instructions: {response[:100]}...")  # Log first 100 characters
-
-     # Split the response into main content and sources
-     parts = response.split("Sources:", 1)
-     main_content = parts[0].strip()
-     sources = parts[1].strip() if len(parts) > 1 else ""
-
-     logging.debug(f"Main content starts with: {main_content[:100]}...")  # Log first 100 characters
-     logging.debug(f"Sources: {sources[:100]}...")  # Log first 100 characters
-
-     return main_content, sources

- def chatbot_interface(message, history, temperature, top_p, repetition_penalty, use_pdf):
-     logging.debug(f"Chatbot interface called with message: {message}")
-     main_content, sources = get_response_with_search(message, temperature, top_p, repetition_penalty, use_pdf)
      formatted_response = f"{main_content}\n\nSources:\n{sources}"
-     logging.debug(f"Formatted response: {formatted_response[:100]}...")  # Log first 100 characters
      return formatted_response

- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")
-
-     with gr.Row():
-         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
-         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
-         update_button = gr.Button("Upload PDF")
-
-     update_output = gr.Textbox(label="Update Status")
-     update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
-
-     with gr.Row():
-         with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Conversation")
-             msg = gr.Textbox(label="Ask a question")
-             submit_button = gr.Button("Submit")
-         with gr.Column(scale=1):
-             temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.1)
-             top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
-             repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1)
-             use_pdf = gr.Checkbox(label="Use PDF Documents", value=False)
-
-     def respond(message, chat_history, temperature, top_p, repetition_penalty, use_pdf):
-         bot_message = chatbot_interface(message, chat_history, temperature, top_p, repetition_penalty, use_pdf)
-         chat_history.append((message, bot_message))
-         return "", chat_history
-
-     submit_button.click(respond, inputs=[msg, chatbot, temperature, top_p, repetition_penalty, use_pdf], outputs=[msg, chatbot])
-
-     clear_button = gr.Button("Clear Cache")
-     clear_output = gr.Textbox(label="Cache Status")
-     clear_button.click(clear_cache, inputs=[], outputs=clear_output)

  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ from huggingface_hub import InferenceApi
+ from duckduckgo_search import DDGS
  import requests
+ import json
  from typing import List
  from pydantic import BaseModel, Field
+ import os

  # Environment variables and configurations
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

+ # Function to perform a DuckDuckGo search
  def duckduckgo_search(query):
      with DDGS() as ddgs:
+         results = ddgs.text(query, max_results=5)
      return results

+ class CitingSources(BaseModel):
+     sources: List[str] = Field(
+         ...,
+         description="List of sources to cite. Should be an URL of the source."
+     )

+ def get_response_with_search(query):
+     # Perform the web search
+     search_results = duckduckgo_search(query)
+
+     # Use the search results as context for the model
+     context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
+                         for result in search_results if 'body' in result)
+
+     # Prompt formatted for Mistral-7B-Instruct
      prompt = f"""<s>[INST] Using the following context:
  {context}
  Write a detailed and complete research document that fulfills the following user request: '{query}'
+ After writing the document, please provide a list of sources used in your response. [/INST]"""
+
+     # API endpoint for Mistral-7B-Instruct-v0.3
+     API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
+
+     # Headers
+     headers = {"Authorization": f"Bearer {huggingface_token}"}
+
+     # Payload
+     payload = {
+         "inputs": prompt,
+         "parameters": {
+             "max_new_tokens": 1000,
+             "temperature": 0.7,
+             "top_p": 0.95,
+             "top_k": 40,
+             "repetition_penalty": 1.1
+         }
+     }
+
+     # Make the API call
+     response = requests.post(API_URL, headers=headers, json=payload)
+
+     if response.status_code == 200:
+         result = response.json()
+         if isinstance(result, list) and len(result) > 0:
+             generated_text = result[0].get('generated_text', 'No text generated')
+
+             # Remove the instruction part
+             content_start = generated_text.find("[/INST]")
+             if content_start != -1:
+                 generated_text = generated_text[content_start + 7:].strip()
+
+             # Split the response into main content and sources
+             parts = generated_text.split("Sources:", 1)
+             main_content = parts[0].strip()
+             sources = parts[1].strip() if len(parts) > 1 else ""
+
+             return main_content, sources
+         else:
+             return f"Unexpected response format: {result}", ""
+     else:
+         return f"Error: API returned status code {response.status_code}", ""

+ def chatbot_interface(message, history):
+     main_content, sources = get_response_with_search(message)
      formatted_response = f"{main_content}\n\nSources:\n{sources}"
      return formatted_response

+ # Gradio chatbot interface
+ iface = gr.ChatInterface(
+     fn=chatbot_interface,
+     title="AI-powered Web Search Assistant",
+     description="Ask questions, and I'll search the web and provide answers using the Mistral-7B-Instruct model.",
+     examples=[
+         ["What are the latest developments in AI?"],
+         ["Tell me about recent updates on GitHub"],
+         ["What are the best hotels in Galapagos, Ecuador?"],
+         ["Summarize recent advancements in Python programming"],
+     ],
+     retry_btn="Retry",
+     undo_btn="Undo",
+     clear_btn="Clear",
+ )

  if __name__ == "__main__":
+     iface.launch(share=True)