Update app.py
app.py CHANGED
@@ -2,11 +2,13 @@ import os
 import fitz  # PyMuPDF
 from docx import Document
 from sentence_transformers import SentenceTransformer
-from …
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from nltk.tokenize import sent_tokenize
 import torch
 import gradio as gr
+import pickle

 # Function to extract text from a PDF file
 def extract_text_from_pdf(pdf_path):
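The new import paths match the pre-0.1.x LangChain package layout. On newer LangChain releases the same classes live in `langchain_community`, so a sketch of the equivalent imports (assuming `langchain-community` is installed) would be:

```python
# Equivalent imports on newer LangChain releases; the original
# langchain.vectorstores / langchain.embeddings paths are deprecated there.
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
```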
@@ -38,31 +40,31 @@ generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
 retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
 retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)

+# Initialize FAISS index using LangChain
+embedding_dimension = embedding_model.get_sentence_embedding_dimension()
+faiss_index = FAISS(HuggingFaceEmbeddings(embedding_model), dimension=embedding_dimension)
+
 # Load or create FAISS index
 index_path = "faiss_index.pkl"
 if os.path.exists(index_path):
     with open(index_path, "rb") as f:
-        …
+        faiss_index = pickle.load(f)
     print("Loaded FAISS index from faiss_index.pkl")
 else:
-    …
-    index = FAISS(embedding_dimension=embedding_model.get_sentence_embedding_dimension())
-    with open(index_path, "wb") as f:
-        FAISS.save(index, f)
-    print("Created new FAISS index and saved to faiss_index.pkl")
+    print("Created new FAISS index")

 def preprocess_text(text):
     sentences = sent_tokenize(text)
     return sentences

 def upload_files(files):
-    global …
+    global faiss_index
     try:
-        for …
-            if …
-                text = extract_text_from_pdf(…
-            elif …
-                text = extract_text_from_docx(…
+        for file in files:
+            if file.name.endswith('.pdf'):
+                text = extract_text_from_pdf(file.name)
+            elif file.name.endswith('.docx'):
+                text = extract_text_from_docx(file.name)
             else:
                 return {"error": "Unsupported file format"}

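As committed, `FAISS(HuggingFaceEmbeddings(embedding_model), dimension=embedding_dimension)` does not match LangChain's API: `HuggingFaceEmbeddings` takes a model name string, not a `SentenceTransformer` instance, and LangChain's `FAISS` wrapper has no `dimension` keyword; it is normally built through classmethods. A minimal sketch of the intended setup, where the model name is a stand-in (the Space's actual `embedding_model` is defined outside the hunks shown here):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Stand-in model name: the actual name used for embedding_model
# is defined outside the hunks shown in this diff.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# LangChain's FAISS wrapper is created via classmethods such as
# from_texts, which build the underlying index from an initial batch.
faiss_index = FAISS.from_texts(["placeholder"], embeddings)

# The supported persistence path is save_local/load_local, which avoids
# pickling the whole store object.
faiss_index.save_local("faiss_index")
faiss_index = FAISS.load_local("faiss_index", embeddings)
```

Recent `langchain_community` releases additionally require `allow_dangerous_deserialization=True` when calling `load_local`.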
@@ -71,7 +73,11 @@ def upload_files(files):

        # Encode sentences and add to FAISS index
        embeddings = embedding_model.encode(sentences)
-        …
+        faiss_index.add_texts(sentences, embeddings)
+
+        # Save the updated index
+        with open(index_path, "wb") as f:
+            pickle.dump(faiss_index, f)

         return {"message": "Files processed successfully"}
     except Exception as e:
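`add_texts(sentences, embeddings)` is likely a bug: the second positional parameter of LangChain's `add_texts` is `metadatas`, and the store computes embeddings itself through its attached embedding function. A sketch of the two supported ways to add the sentences, assuming the store was built as above:

```python
# Simplest path: let the store embed the sentences itself.
faiss_index.add_texts(sentences)

# If the precomputed SentenceTransformer vectors should be reused,
# pair each text with its vector explicitly via add_embeddings.
embeddings = embedding_model.encode(sentences)
faiss_index.add_embeddings(list(zip(sentences, embeddings)))
```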
@@ -88,22 +94,18 @@ def process_and_query(state, files, question):
     # Preprocess the question
     question_embedding = embedding_model.encode([question])

-    # …
-
-    retriever_outputs = retriever(**retriever_tokenizer(question, return_tensors="pt"))
-    retriever_hidden_states = retriever_outputs.hidden_states[-1]  # Last hidden state
-
-    # Search the FAISS index for similar passages based on retrieved hidden states
-    distances, retrieved_ids = index.search(retriever_hidden_states.cpu().numpy(), k=5)  # Retrieve top 5 passages
+    # Search the FAISS index for similar passages
+    distances, retrieved_ids = faiss_index.similarity_search_with_score(question_embedding, k=5)  # Retrieve top 5 passages

     # Get the retrieved passages from the document text
     retrieved_passages = [state["processed_text"].split("\n")[i] for i in retrieved_ids.flatten()]

     # Use generator model to generate response based on question and retrieved passages
-    combined_input = …
+    combined_input = question + " ".join(retrieved_passages)
+    inputs = generator_tokenizer(combined_input, return_tensors="pt")
     with torch.no_grad():
-        generator_outputs = generator(**…
-        generated_text = generator_tokenizer.decode(generator_outputs…
+        generator_outputs = generator.generate(**inputs)
+        generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)

     # Update conversation history
     state["conversation"].append({"question": question, "answer": generated_text})
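The new retrieval call still mismatches the API: `similarity_search_with_score` takes the raw query string (use `similarity_search_by_vector` for a precomputed embedding) and returns `(Document, score)` pairs, not `(distances, retrieved_ids)`, so the later `retrieved_ids.flatten()` would fail. A sketch of the retrieval-plus-generation step under those APIs; the space separator and `max_new_tokens=200` are assumptions, not values from this commit:

```python
# Retrieve the top 5 passages; results are (Document, score) pairs,
# not parallel arrays of distances and integer ids.
results = faiss_index.similarity_search_with_score(question, k=5)
retrieved_passages = [doc.page_content for doc, _score in results]

# Insert a separator so the question and first passage don't run together.
combined_input = question + " " + " ".join(retrieved_passages)
inputs = generator_tokenizer(combined_input, return_tensors="pt", truncation=True)
with torch.no_grad():
    # max_new_tokens=200 is an arbitrary cap, not taken from the commit.
    generator_outputs = generator.generate(**inputs, max_new_tokens=200)
generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
```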
@@ -131,3 +133,4 @@ with gr.Blocks() as demo:
 demo.launch()


+