Update app.py
app.py CHANGED
@@ -8,6 +8,7 @@ from nltk.tokenize import sent_tokenize
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from sentence_transformers import SentenceTransformer
 import gradio as gr
+import torch
 
 # Download NLTK punkt tokenizer if not already downloaded
 import nltk
@@ -16,8 +17,25 @@ nltk.download('punkt')
 # Initialize Sentence Transformer model for embeddings
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
+# Initialize Hugging Face API token
+api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+if not api_token:
+    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
+
+# Initialize RAG models from Hugging Face
+generator_model_name = "facebook/bart-base"
+retriever_model_name = "facebook/bart-base"
+generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
+retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
+retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
+
 # Initialize FAISS index using LangChain
-
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+faiss_index = FAISS(embedding_function=hf_embeddings)
 
 # Function to extract text from a PDF file
 def extract_text_from_pdf(pdf_data):
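Note on the new index setup: `FAISS(embedding_function=hf_embeddings)` will raise a `TypeError` in `langchain_community`, whose `FAISS` constructor also requires an `index`, a `docstore`, and an `index_to_docstore_id` mapping; it also loads `all-MiniLM-L6-v2` a second time alongside the `embedding_model` above. A minimal sketch of an empty store built the way the library expects (384 is the embedding width of `all-MiniLM-L6-v2`):

import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore

# Build an empty FAISS store; texts would be added later,
# e.g. faiss_index.add_texts(sentences) inside upload_files().
index = faiss.IndexFlatL2(384)  # L2 index sized for all-MiniLM-L6-v2
faiss_index = FAISS(
    embedding_function=hf_embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)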
@@ -45,7 +63,7 @@ def preprocess_text(text):
     sentences = sent_tokenize(text)
     return sentences
 
-# Function to handle file uploads
+# Function to handle file uploads and update FAISS index
 def upload_files(files):
     global faiss_index
     try:
@@ -82,13 +100,35 @@ def upload_files(files):
         print(f"Error processing files: {e}")
         return {"error": str(e)}  # Provide informative error message
 
-# Function to process queries
+# Function to process queries using RAG model
 def process_and_query(state, question):
     if question:
         try:
-            #
-
-
+            # Search the FAISS index for similar passages
+            question_embedding = embedding_model.encode([question])
+            D, I = faiss_index.search(np.array(question_embedding), k=5)
+            retrieved_passages = [faiss_index.index_to_text(i) for i in I[0]]
+
+            # Use generator model to generate response based on question and retrieved passages
+            prompt_template = """
+            Answer the question as detailed as possible from the provided context,
+            make sure to provide all the details, if the answer is not in
+            provided context just say, "answer is not available in the context",
+            don't provide the wrong answer
+            Context:\n{context}\n
+            Question:\n{question}\n
+            Answer:
+            """
+            combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
+            inputs = generator_tokenizer(combined_input, return_tensors="pt")
+            with torch.no_grad():
+                generator_outputs = generator.generate(**inputs)
+            generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
+
+            # Update conversation history
+            state.append({"question": question, "answer": generated_text})
+
+            return {"message": generated_text, "conversation": state}
         except Exception as e:
             print(f"Error processing query: {e}")
             return {"error": str(e)}
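Note on the retrieval block: the LangChain `FAISS` wrapper has no `search()` method returning `(D, I)` arrays and no `index_to_text()` method (and `np` is never imported in the hunks shown); that calling style belongs to a raw `faiss` index, which the wrapper keeps at `faiss_index.index`. The `retriever` model loaded above is also never used. A sketch of the same step through the wrapper's own API, assuming the store from the previous note:

# similarity_search embeds the query itself and returns Documents,
# so the manual embedding_model.encode() call is unnecessary.
docs = faiss_index.similarity_search(question, k=5)
retrieved_passages = [doc.page_content for doc in docs]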
@@ -97,21 +137,29 @@ def process_and_query(state, question):
 
 # Define the Gradio interface
 def main():
+    upload_tab = gr.Interface(
+        fn=upload_files,
+        inputs=gr.inputs.File(label="Upload PDF or DOCX files", multiple=True),
+        outputs=gr.outputs.Text(label="Upload Status", default="No file uploaded yet", type="textbox"),
+        live=True,
+        capture_session=True
+    )
+
+    query_tab = gr.Interface(
+        fn=process_and_query,
+        inputs=gr.inputs.Textbox(label="Enter your query"),
+        outputs=gr.outputs.Textbox(label="Query Response", default="No query processed yet", type="textbox"),
+        live=True,
+        capture_session=True
+    )
+
     gr.Interface(
-        fn=None,
-        inputs=[
-            gr.Tab("Upload Files",
-
-
-
-            ])),
-            gr.Tab("Query", gr.Interface.Layout([
-                gr.Textbox("Enter your query", label="Query Input"),
-                gr.Button("Search", onclick=process_and_query),
-                gr.Textbox("Query Response", default="No query processed yet", multiline=True)
-            ]))
-        ]),
-        outputs=gr.Textbox("Output", label="Output", default="Output will be shown here", multiline=True),
+        fn=None,
+        inputs=[
+            gr.Interface.Tab("Upload Files", upload_tab),
+            gr.Interface.Tab("Query", query_tab)
+        ],
+        outputs=gr.outputs.Textbox(label="Output", default="Output will be shown here", type="textbox"),
         live=True,
         capture_session=True
     ).launch()
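Note on the interface: `gr.inputs` / `gr.outputs`, `capture_session`, and `gr.Interface.Tab` do not exist in current Gradio releases, and an `Interface` cannot be passed as another `Interface`'s input. A hedged sketch of `main()` using `gr.TabbedInterface`, the supported way to compose the two tabs (the `gr.State([])` input stands in for the `state` argument of `process_and_query`):

def main():
    upload_tab = gr.Interface(
        fn=upload_files,
        inputs=gr.File(label="Upload PDF or DOCX files", file_count="multiple"),
        outputs=gr.Textbox(label="Upload Status"),
    )
    query_tab = gr.Interface(
        fn=process_and_query,
        # gr.State carries the conversation history between calls
        inputs=[gr.State([]), gr.Textbox(label="Enter your query")],
        outputs=gr.Textbox(label="Query Response"),
    )
    # Two tabs, one app; replaces the nested gr.Interface(...) above
    gr.TabbedInterface([upload_tab, query_tab], ["Upload Files", "Query"]).launch()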