NaimaAqeel committed
Commit 9ce0b96 (verified) · 1 Parent(s): 8d35da0

Update app.py

Files changed (1):
  1. app.py +68 -20

app.py CHANGED
@@ -8,6 +8,7 @@ from nltk.tokenize import sent_tokenize
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from sentence_transformers import SentenceTransformer
 import gradio as gr
+import torch
 
 # Download NLTK punkt tokenizer if not already downloaded
 import nltk
@@ -16,8 +17,25 @@ nltk.download('punkt')
 # Initialize Sentence Transformer model for embeddings
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
+# Initialize Hugging Face API token
+api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+if not api_token:
+    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
+
+# Initialize RAG models from Hugging Face
+generator_model_name = "facebook/bart-base"
+retriever_model_name = "facebook/bart-base"
+generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
+retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
+retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
+
 # Initialize FAISS index using LangChain
-faiss_index = None  # Initialize or load your FAISS index as needed
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+faiss_index = FAISS(embedding_function=hf_embeddings)
 
 # Function to extract text from a PDF file
 def extract_text_from_pdf(pdf_data):
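Two review notes on this hunk. First, `api_token` is validated but never used afterwards, and the second BART model bound to `retriever` is loaded yet never called; if the token is meant to gate the downloads, it would have to be passed along, e.g. `AutoModelForSeq2SeqLM.from_pretrained(generator_model_name, token=api_token)`. Second, `FAISS(embedding_function=hf_embeddings)` does not match the constructor of LangChain's FAISS wrapper, which also requires a raw index, a docstore, and an id mapping. A minimal sketch of a working empty store (the `dim` probe and the exact arguments are my assumption, not part of the commit):

```python
# Sketch: an empty LangChain FAISS store. The bare FAISS(embedding_function=...)
# call in the hunk fails because index, docstore, and index_to_docstore_id
# are required arguments as well.
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
dim = len(hf_embeddings.embed_query("dimension probe"))  # 384 for all-MiniLM-L6-v2
faiss_index = FAISS(
    embedding_function=hf_embeddings,
    index=faiss.IndexFlatL2(dim),   # flat L2 index over the embedding dimension
    docstore=InMemoryDocstore({}),  # holds the raw texts
    index_to_docstore_id={},        # maps FAISS positions to docstore ids
)
```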
@@ -45,7 +63,7 @@ def preprocess_text(text):
     sentences = sent_tokenize(text)
     return sentences
 
-# Function to handle file uploads
+# Function to handle file uploads and update FAISS index
 def upload_files(files):
     global faiss_index
     try:
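The body of `upload_files` falls outside the diff context, so only the renamed comment is visible here. For what the new comment promises, the incremental update on the LangChain store is a one-liner; a hedged sketch (the helper name is hypothetical, not from the commit):

```python
# Sketch: how upload_files could push preprocessed sentences into the store.
# add_texts embeds each string with the store's embedding function and
# appends it to the FAISS index in place, returning the new document ids.
def index_sentences(sentences):
    return faiss_index.add_texts(sentences)
```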
@@ -82,13 +100,35 @@ def upload_files(files):
         print(f"Error processing files: {e}")
         return {"error": str(e)}  # Provide informative error message
 
-# Function to process queries
+# Function to process queries using RAG model
 def process_and_query(state, question):
     if question:
         try:
-            # Placeholder response based on query processing
-            response_message = "Placeholder response based on query processing"
-            return {"message": response_message, "conversation": state}
+            # Search the FAISS index for similar passages
+            question_embedding = embedding_model.encode([question])
+            D, I = faiss_index.search(np.array(question_embedding), k=5)
+            retrieved_passages = [faiss_index.index_to_text(i) for i in I[0]]
+
+            # Use generator model to generate response based on question and retrieved passages
+            prompt_template = """
+            Answer the question as detailed as possible from the provided context,
+            make sure to provide all the details, if the answer is not in
+            provided context just say, "answer is not available in the context",
+            don't provide the wrong answer
+            Context:\n{context}\n
+            Question:\n{question}\n
+            Answer:
+            """
+            combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
+            inputs = generator_tokenizer(combined_input, return_tensors="pt")
+            with torch.no_grad():
+                generator_outputs = generator.generate(**inputs)
+            generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
+
+            # Update conversation history
+            state.append({"question": question, "answer": generated_text})
+
+            return {"message": generated_text, "conversation": state}
         except Exception as e:
             print(f"Error processing query: {e}")
             return {"error": str(e)}
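The retrieval block cannot run against the store created above: a LangChain FAISS object has no `.search(array, k=...)` method with this signature and no `index_to_text()` at all, and `np` is only available if numpy is imported in the unshown lines 1-7. The equivalent retrieval is a single call on the wrapper; a sketch under those assumptions:

```python
# Sketch: the retrieval step expressed with LangChain's actual FAISS API.
# similarity_search embeds the query itself and returns the k nearest
# Documents, replacing the faiss_index.search / index_to_text calls above.
docs = faiss_index.similarity_search(question, k=5)
retrieved_passages = [doc.page_content for doc in docs]
combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
```

Two smaller generation caveats: bart-base accepts at most 1024 input tokens, so the tokenizer call should pass `truncation=True`, and `generate(**inputs)` with the default generation config may cut answers short unless `max_new_tokens` is set.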
@@ -97,21 +137,29 @@ def process_and_query(state, question):
 
 # Define the Gradio interface
 def main():
+    upload_tab = gr.Interface(
+        fn=upload_files,
+        inputs=gr.inputs.File(label="Upload PDF or DOCX files", multiple=True),
+        outputs=gr.outputs.Text(label="Upload Status", default="No file uploaded yet", type="textbox"),
+        live=True,
+        capture_session=True
+    )
+
+    query_tab = gr.Interface(
+        fn=process_and_query,
+        inputs=gr.inputs.Textbox(label="Enter your query"),
+        outputs=gr.outputs.Textbox(label="Query Response", default="No query processed yet", type="textbox"),
+        live=True,
+        capture_session=True
+    )
+
     gr.Interface(
-        fn=None,  # Replace with your function that handles interface logic
-        inputs=gr.Interface.Layout([
-            gr.Tab("Upload Files", gr.Interface.Layout([
-                gr.File(label="Upload PDF or DOCX files", multiple=True),
-                gr.Button("Upload", onclick=upload_files),
-                gr.Textbox("Upload Status", default="No file uploaded yet", multiline=True)
-            ])),
-            gr.Tab("Query", gr.Interface.Layout([
-                gr.Textbox("Enter your query", label="Query Input"),
-                gr.Button("Search", onclick=process_and_query),
-                gr.Textbox("Query Response", default="No query processed yet", multiline=True)
-            ]))
-        ]),
-        outputs=gr.Textbox("Output", label="Output", default="Output will be shown here", multiline=True),
+        fn=None,
+        inputs=[
+            gr.Interface.Tab("Upload Files", upload_tab),
+            gr.Interface.Tab("Query", query_tab)
+        ],
+        outputs=gr.outputs.Textbox(label="Output", default="Output will be shown here", type="textbox"),
         live=True,
         capture_session=True
     ).launch()
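The rewritten `main()` still targets an API that Gradio does not have: `gr.Interface.Tab` does not exist, the `gr.inputs`/`gr.outputs` namespaces were removed in Gradio 3, `capture_session` is gone, and nesting `Interface` objects inside `inputs` is not supported. A sketch of the intended two-tab layout with the supported Blocks API (component names are mine, and it assumes `process_and_query` is adapted to return displayable text):

```python
# Sketch: the two-tab UI rebuilt with gr.Blocks, the supported way to
# compose tabs from the existing upload_files and process_and_query handlers.
import gradio as gr

def main():
    with gr.Blocks() as demo:
        with gr.Tab("Upload Files"):
            files = gr.File(label="Upload PDF or DOCX files", file_count="multiple")
            upload_btn = gr.Button("Upload")
            status = gr.Textbox(label="Upload Status", value="No file uploaded yet")
            upload_btn.click(upload_files, inputs=files, outputs=status)
        with gr.Tab("Query"):
            history = gr.State([])  # conversation state passed to the handler
            query = gr.Textbox(label="Enter your query")
            search_btn = gr.Button("Search")
            answer = gr.Textbox(label="Query Response", value="No query processed yet")
            search_btn.click(process_and_query, inputs=[history, query], outputs=answer)
    demo.launch()
```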
 