Shreyas094 commited on
Commit
781b94b
1 Parent(s): 7f18930

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -76
app.py CHANGED
@@ -1,102 +1,170 @@
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceApi
3
- from duckduckgo_search import DDGS
4
  import requests
5
- import json
 
 
 
6
  from typing import List
7
  from pydantic import BaseModel, Field
 
 
 
 
 
 
 
 
 
8
 
9
- # Global variables
10
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Function to perform a DuckDuckGo search
13
  def duckduckgo_search(query):
14
  with DDGS() as ddgs:
15
  results = ddgs.text(query, max_results=5)
16
  return results
17
 
18
- class CitingSources(BaseModel):
19
- sources: List[str] = Field(
20
- ...,
21
- description="List of sources to cite. Should be an URL of the source."
22
- )
 
 
 
 
 
 
 
 
23
 
24
- def get_response_with_search(query):
25
- # Perform the web search
26
- search_results = duckduckgo_search(query)
27
-
28
- # Use the search results as context for the model
29
- context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
30
- for result in search_results if 'body' in result)
31
-
32
- # Prompt formatted for Mistral-7B-Instruct
33
  prompt = f"""<s>[INST] Using the following context:
34
  {context}
35
  Write a detailed and complete research document that fulfills the following user request: '{query}'
36
  After writing the document, please provide a list of sources used in your response. [/INST]"""
 
 
37
 
38
- # API endpoint for Mistral-7B-Instruct-v0.3
39
- API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
40
 
41
- # Headers
42
- headers = {"Authorization": f"Bearer {huggingface_token}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Payload
45
- payload = {
46
- "inputs": prompt,
47
- "parameters": {
48
- "max_new_tokens": 1000,
49
- "temperature": 0.7,
50
- "top_p": 0.95,
51
- "top_k": 40,
52
- "repetition_penalty": 1.1
53
- }
54
- }
55
 
56
- # Make the API call
57
- response = requests.post(API_URL, headers=headers, json=payload)
58
 
59
- if response.status_code == 200:
60
- result = response.json()
61
- if isinstance(result, list) and len(result) > 0:
62
- generated_text = result[0].get('generated_text', 'No text generated')
63
-
64
- # Remove the instruction part
65
- content_start = generated_text.find("[/INST]")
66
- if content_start != -1:
67
- generated_text = generated_text[content_start + 7:].strip()
68
-
69
- # Split the response into main content and sources
70
- parts = generated_text.split("Sources:", 1)
71
- main_content = parts[0].strip()
72
- sources = parts[1].strip() if len(parts) > 1 else ""
73
-
74
- return main_content, sources
75
- else:
76
- return f"Unexpected response format: {result}", ""
77
- else:
78
- return f"Error: API returned status code {response.status_code}", ""
79
 
80
- def chatbot_interface(message, history):
81
- main_content, sources = get_response_with_search(message)
82
- formatted_response = f"{main_content}\n\nSources:\n{sources}"
83
- return formatted_response
84
 
85
- # Gradio chatbot interface
86
- iface = gr.ChatInterface(
87
- fn=chatbot_interface,
88
- title="AI-powered Web Search Assistant",
89
- description="Ask questions, and I'll search the web and provide answers using the Mistral-7B-Instruct model.",
90
- examples=[
91
- ["What are the latest developments in AI?"],
92
- ["Tell me about recent updates on GitHub"],
93
- ["What are the best hotels in Galapagos, Ecuador?"],
94
- ["Summarize recent advancements in Python programming"],
95
- ],
96
- retry_btn="Retry",
97
- undo_btn="Undo",
98
- clear_btn="Clear",
99
- )
100
 
101
  if __name__ == "__main__":
102
- iface.launch()
 
1
+ import os
2
+ import json
3
+ import re
4
  import gradio as gr
 
 
5
  import requests
6
+ import random
7
+ import urllib.parse
8
+ from tempfile import NamedTemporaryFile
9
+ from bs4 import BeautifulSoup
10
  from typing import List
11
  from pydantic import BaseModel, Field
12
+ from huggingface_hub import InferenceApi
13
+ from duckduckgo_search import DDGS
14
+ from langchain_community.vectorstores import FAISS
15
+ from langchain_community.document_loaders import PyPDFLoader
16
+ from langchain_community.embeddings import HuggingFaceEmbeddings
17
+ from langchain_community.llms import HuggingFaceHub
18
+ from langchain_core.documents import Document
19
+ from sentence_transformers import SentenceTransformer
20
+ from llama_parse import LlamaParse
21
 
22
+ # Environment variables and configurations
23
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
24
+ llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
25
+
26
+ # Initialize SentenceTransformer and LlamaParse
27
+ sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
28
+ llama_parser = LlamaParse(
29
+ api_key=llama_cloud_api_key,
30
+ result_type="markdown",
31
+ num_workers=4,
32
+ verbose=True,
33
+ language="en",
34
+ )
35
+
36
+ def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
37
+ if parser == "pypdf":
38
+ loader = PyPDFLoader(file.name)
39
+ return loader.load_and_split()
40
+ elif parser == "llamaparse":
41
+ try:
42
+ documents = llama_parser.load_data(file.name)
43
+ return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
44
+ except Exception as e:
45
+ print(f"Error using Llama Parse: {str(e)}")
46
+ print("Falling back to PyPDF parser")
47
+ loader = PyPDFLoader(file.name)
48
+ return loader.load_and_split()
49
+ else:
50
+ raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
51
+
52
+ def update_vectors(files, parser):
53
+ if not files:
54
+ return "Please upload at least one PDF file."
55
+
56
+ embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
57
+ total_chunks = 0
58
+
59
+ all_data = []
60
+ for file in files:
61
+ data = load_document(file, parser)
62
+ all_data.extend(data)
63
+ total_chunks += len(data)
64
+
65
+ if os.path.exists("faiss_database"):
66
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
67
+ database.add_documents(all_data)
68
+ else:
69
+ database = FAISS.from_documents(all_data, embed)
70
+
71
+ database.save_local("faiss_database")
72
+
73
+ return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."
74
+
75
+ def clear_cache():
76
+ if os.path.exists("faiss_database"):
77
+ os.remove("faiss_database")
78
+ return "Cache cleared successfully."
79
+ else:
80
+ return "No cache to clear."
81
+
82
+ def get_model(temperature, top_p, repetition_penalty):
83
+ return HuggingFaceHub(
84
+ repo_id="mistralai/Mistral-7B-Instruct-v0.3",
85
+ model_kwargs={
86
+ "temperature": temperature,
87
+ "top_p": top_p,
88
+ "repetition_penalty": repetition_penalty,
89
+ "max_length": 1000
90
+ },
91
+ huggingfacehub_api_token=huggingface_token
92
+ )
93
 
 
94
  def duckduckgo_search(query):
95
  with DDGS() as ddgs:
96
  results = ddgs.text(query, max_results=5)
97
  return results
98
 
99
+ def get_response_with_search(query, temperature, top_p, repetition_penalty, use_pdf=False):
100
+ model = get_model(temperature, top_p, repetition_penalty)
101
+ embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
102
+
103
+ if use_pdf and os.path.exists("faiss_database"):
104
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
105
+ retriever = database.as_retriever()
106
+ relevant_docs = retriever.get_relevant_documents(query)
107
+ context = "\n".join([f"Content: {doc.page_content}\nSource: {doc.metadata['source']}\n" for doc in relevant_docs])
108
+ else:
109
+ search_results = duckduckgo_search(query)
110
+ context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
111
+ for result in search_results if 'body' in result)
112
 
 
 
 
 
 
 
 
 
 
113
  prompt = f"""<s>[INST] Using the following context:
114
  {context}
115
  Write a detailed and complete research document that fulfills the following user request: '{query}'
116
  After writing the document, please provide a list of sources used in your response. [/INST]"""
117
+
118
+ response = model(prompt)
119
 
120
+ main_content, sources = split_response(response)
 
121
 
122
+ return main_content, sources
123
+
124
+ def split_response(response):
125
+ parts = response.split("Sources:", 1)
126
+ main_content = parts[0].strip()
127
+ sources = parts[1].strip() if len(parts) > 1 else ""
128
+ return main_content, sources
129
+
130
+ def chatbot_interface(message, history, temperature, top_p, repetition_penalty, use_pdf):
131
+ main_content, sources = get_response_with_search(message, temperature, top_p, repetition_penalty, use_pdf)
132
+ formatted_response = f"{main_content}\n\nSources:\n{sources}"
133
+ return formatted_response
134
+
135
+ # Gradio interface
136
+ with gr.Blocks() as demo:
137
+ gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")
138
 
139
+ with gr.Row():
140
+ file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
141
+ parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
142
+ update_button = gr.Button("Upload PDF")
 
 
 
 
 
 
 
143
 
144
+ update_output = gr.Textbox(label="Update Status")
145
+ update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
146
 
147
+ with gr.Row():
148
+ with gr.Column(scale=2):
149
+ chatbot = gr.Chatbot(label="Conversation")
150
+ msg = gr.Textbox(label="Ask a question")
151
+ submit_button = gr.Button("Submit")
152
+ with gr.Column(scale=1):
153
+ temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.1)
154
+ top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
155
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1)
156
+ use_pdf = gr.Checkbox(label="Use PDF Documents", value=False)
 
 
 
 
 
 
 
 
 
 
157
 
158
+ def respond(message, chat_history, temperature, top_p, repetition_penalty, use_pdf):
159
+ bot_message = chatbot_interface(message, chat_history, temperature, top_p, repetition_penalty, use_pdf)
160
+ chat_history.append((message, bot_message))
161
+ return "", chat_history
162
 
163
+ submit_button.click(respond, inputs=[msg, chatbot, temperature, top_p, repetition_penalty, use_pdf], outputs=[msg, chatbot])
164
+
165
+ clear_button = gr.Button("Clear Cache")
166
+ clear_output = gr.Textbox(label="Cache Status")
167
+ clear_button.click(clear_cache, inputs=[], outputs=clear_output)
 
 
 
 
 
 
 
 
 
 
168
 
169
  if __name__ == "__main__":
170
+ demo.launch()