Shreyas094 commited on
Commit
687c2f0
1 Parent(s): 348631b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -102
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import gradio as gr
4
  import pandas as pd
5
  from tempfile import NamedTemporaryFile
6
-
7
  from langchain_core.prompts import ChatPromptTemplate
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_community.document_loaders import PyPDFLoader
@@ -11,119 +11,125 @@ from langchain_core.output_parsers import StrOutputParser
11
  from langchain_community.embeddings import HuggingFaceEmbeddings
12
  from langchain_community.llms import HuggingFaceHub
13
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
14
-
 
15
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
16
-
17
- def load_and_split_document(file):
18
- """Loads and splits the document into pages."""
19
- loader = PyPDFLoader(file.name)
20
- data = loader.load_and_split()
21
- return data
22
-
 
 
 
 
 
 
 
 
 
23
  def get_embeddings():
24
- return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
25
-
26
  def create_or_update_database(data, embeddings):
27
- if os.path.exists("faiss_database"):
28
- db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
29
- db.add_documents(data)
30
- else:
31
- db = FAISS.from_documents(data, embeddings)
32
- db.save_local("faiss_database")
33
-
 
 
 
 
 
34
  prompt = """
35
  Answer the question based only on the following context:
36
  {context}
37
  Question: {question}
38
-
39
  Provide a concise and direct answer to the question:
40
  """
41
-
42
- def get_model():
43
- return HuggingFaceHub(
44
- repo_id="mistralai/Mistral-7B-Instruct-v0.3",
45
- model_kwargs={"temperature": 0.2, "max_length": 512},
46
- huggingfacehub_api_token=huggingface_token
47
- )
48
-
 
 
 
49
  def generate_chunked_response(model, prompt, max_tokens=500, max_chunks=5):
50
- full_response = ""
51
- for i in range(max_chunks):
52
- chunk = model(prompt + full_response, max_new_tokens=max_tokens)
53
- full_response += chunk
54
- if chunk.strip().endswith((".", "!", "?")):
55
- break
56
- return full_response.strip()
57
-
58
  def response(database, model, question):
59
- prompt_val = ChatPromptTemplate.from_template(prompt)
60
- retriever = database.as_retriever()
61
-
62
- context = retriever.get_relevant_documents(question)
63
- context_str = "\n".join([doc.page_content for doc in context])
64
-
65
- formatted_prompt = prompt_val.format(context=context_str, question=question)
66
-
67
- ans = generate_chunked_response(model, formatted_prompt)
68
- return ans
69
-
70
- def update_vectors(files):
71
- if not files:
72
- return "Please upload at least one PDF file."
73
-
74
- embed = get_embeddings()
75
- total_chunks = 0
76
-
77
- for file in files:
78
- data = load_and_split_document(file)
79
- create_or_update_database(data, embed)
80
- total_chunks += len(data)
81
-
82
- return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
83
-
84
- def ask_question(question):
85
- if not question:
86
- return "Please enter a question."
87
- embed = get_embeddings()
88
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
89
- model = get_model()
90
- return response(database, model, question)
91
-
92
  def extract_db_to_excel():
93
- embed = get_embeddings()
94
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
95
-
96
- documents = database.docstore._dict.values()
97
- data = [{"page_content": doc.page_content, "metadata": json.dumps(doc.metadata)} for doc in documents]
98
- df = pd.DataFrame(data)
99
-
100
- with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
101
- excel_path = tmp.name
102
- df.to_excel(excel_path, index=False)
103
-
104
- return excel_path
105
-
106
  # Gradio interface
107
  with gr.Blocks() as demo:
108
- gr.Markdown("# Chat with your PDF documents")
109
-
110
- with gr.Row():
111
- file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
112
- update_button = gr.Button("Update Vector Store")
113
-
114
- update_output = gr.Textbox(label="Update Status")
115
- update_button.click(update_vectors, inputs=[file_input], outputs=update_output)
116
-
117
- with gr.Row():
118
- question_input = gr.Textbox(label="Ask a question about your documents")
119
- submit_button = gr.Button("Submit")
120
-
121
- answer_output = gr.Textbox(label="Answer")
122
- submit_button.click(ask_question, inputs=[question_input], outputs=answer_output)
123
-
124
- extract_button = gr.Button("Extract Database to Excel")
125
- excel_output = gr.File(label="Download Excel File")
126
- extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)
127
-
 
128
  if __name__ == "__main__":
129
- demo.launch()
 
3
  import gradio as gr
4
  import pandas as pd
5
  from tempfile import NamedTemporaryFile
6
+ from typing import List
7
  from langchain_core.prompts import ChatPromptTemplate
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_community.document_loaders import PyPDFLoader
 
11
  from langchain_community.embeddings import HuggingFaceEmbeddings
12
  from langchain_community.llms import HuggingFaceHub
13
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
14
+ from langchain_core.text_splitters import RecursiveCharacterTextSplitter
15
+ from langchain_core.document import Document
16
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
17
+ def load_and_split_document_basic(file):
18
+ """Loads and splits the document into pages."""
19
+ loader = PyPDFLoader(file.name)
20
+ data = loader.load_and_split()
21
+ return data
22
+ def load_and_split_document_recursive(file: NamedTemporaryFile) -> List[Document]:
23
+ """Loads and splits the document into chunks."""
24
+ loader = PyPDFLoader(file.name)
25
+ pages = loader.load()
26
+ text_splitter = RecursiveCharacterTextSplitter(
27
+ chunk_size=1000,
28
+ chunk_overlap=200,
29
+ length_function=len,
30
+ )
31
+ chunks = text_splitter.split_documents(pages)
32
+ return chunks
33
  def get_embeddings():
34
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
35
  def create_or_update_database(data, embeddings):
36
+ if os.path.exists("faiss_database"):
37
+ db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
38
+ db.add_documents(data)
39
+ else:
40
+ db = FAISS.from_documents(data, embeddings)
41
+ db.save_local("faiss_database")
42
+ def clear_cache():
43
+ if os.path.exists("faiss_database"):
44
+ os.remove("faiss_database")
45
+ return "Cache cleared successfully."
46
+ else:
47
+ return "No cache to clear."
48
  prompt = """
49
  Answer the question based only on the following context:
50
  {context}
51
  Question: {question}
 
52
  Provide a concise and direct answer to the question:
53
  """
54
+ def get_model(temperature, top_p, repetition_penalty):
55
+ return HuggingFaceHub(
56
+ repo_id="mistralai/Mistral-7B-Instruct-v0.3",
57
+ model_kwargs={
58
+ "temperature": temperature,
59
+ "top_p": top_p,
60
+ "repetition_penalty": repetition_penalty,
61
+ "max_length": 512
62
+ },
63
+ huggingfacehub_api_token=huggingface_token
64
+ )
65
  def generate_chunked_response(model, prompt, max_tokens=500, max_chunks=5):
66
+ full_response = ""
67
+ for i in range(max_chunks):
68
+ chunk = model(prompt + full_response, max_new_tokens=max_tokens)
69
+ full_response += chunk
70
+ if chunk.strip().endswith((".", "!", "?")):
71
+ break
72
+ return full_response.strip()
 
73
  def response(database, model, question):
74
+ prompt_val = ChatPromptTemplate.from_template(prompt)
75
+ retriever = database.as_retriever()
76
+ context = retriever.get_relevant_documents(question)
77
+ context_str = "\n".join([doc.page_content for doc in context])
78
+ formatted_prompt = prompt_val.format(context=context_str, question=question)
79
+ ans = generate_chunked_response(model, formatted_prompt)
80
+ return ans
81
+ def update_vectors(files, use_recursive_splitter):
82
+ if not files:
83
+ return "Please upload at least one PDF file."
84
+ embed = get_embeddings()
85
+ total_chunks = 0
86
+ for file in files:
87
+ if use_recursive_splitter:
88
+ data = load_and_split_document_recursive(file)
89
+ else:
90
+ data = load_and_split_document_basic(file)
91
+ create_or_update_database(data, embed)
92
+ total_chunks += len(data)
93
+ return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
94
+ def ask_question(question, temperature, top_p, repetition_penalty):
95
+ if not question:
96
+ return "Please enter a question."
97
+ embed = get_embeddings()
98
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
99
+ model = get_model(temperature, top_p, repetition_penalty)
100
+ return response(database, model, question)
 
 
 
 
 
 
101
  def extract_db_to_excel():
102
+ embed = get_embeddings()
103
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
104
+ documents = database.docstore._dict.values()
105
+ data = [{"page_content": doc.page_content, "metadata": json.dumps(doc.metadata)} for doc in documents]
106
+ df = pd.DataFrame(data)
107
+ with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
108
+ excel_path = tmp.name
109
+ df.to_excel(excel_path, index=False)
110
+ return excel_path
 
 
 
 
111
  # Gradio interface
112
  with gr.Blocks() as demo:
113
+ gr.Markdown("# Chat with your PDF documents")
114
+ with gr.Row():
115
+ file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
116
+ update_button = gr.Button("Update Vector Store")
117
+ use_recursive_splitter = gr.Checkbox(label="Use Recursive Text Splitter", value=False)
118
+ update_output = gr.Textbox(label="Update Status")
119
+ update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
120
+ with gr.Row():
121
+ question_input = gr.Textbox(label="Ask a question about your documents")
122
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
123
+ top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
124
+ repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
125
+ submit_button = gr.Button("Submit")
126
+ answer_output = gr.Textbox(label="Answer")
127
+ submit_button.click(ask_question, inputs=[question_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=answer_output)
128
+ extract_button = gr.Button("Extract Database to Excel")
129
+ excel_output = gr.File(label="Download Excel File")
130
+ extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)
131
+ clear_button = gr.Button("Clear Cache")
132
+ clear_output = gr.Textbox(label="Cache Status")
133
+ clear_button.click(clear_cache, inputs=[], outputs=clear_output)
134
  if __name__ == "__main__":
135
+ demo.launch()