NaimaAqeel committed
Commit 377f3f1 · verified · 1 Parent(s): 7fc8bcc

Update app.py

Files changed (1):
  1. app.py +27 -24
app.py CHANGED

@@ -2,11 +2,13 @@ import os
 import fitz  # PyMuPDF
 from docx import Document
 from sentence_transformers import SentenceTransformer
-from langchain_community.vectorstores import FAISS
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from nltk.tokenize import sent_tokenize
 import torch
 import gradio as gr
+import pickle
 
 # Function to extract text from a PDF file
 def extract_text_from_pdf(pdf_path):
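For reference, these imports have moved between packages across LangChain releases: on current versions the FAISS store and the HuggingFace embedding wrapper live in langchain_community (the old-side import above was already the newer path). A minimal sketch, assuming the langchain-community package is installed; the model name all-MiniLM-L6-v2 is an assumption, since the diff never shows which SentenceTransformer the app loads:

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# HuggingFaceEmbeddings expects a model name string, not a loaded
# SentenceTransformer instance (the commit passes the model object).
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")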
@@ -38,31 +40,31 @@ generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
 retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
 retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
 
+# Initialize FAISS index using LangChain
+embedding_dimension = embedding_model.get_sentence_embedding_dimension()
+faiss_index = FAISS(HuggingFaceEmbeddings(embedding_model), dimension=embedding_dimension)
+
 # Load or create FAISS index
 index_path = "faiss_index.pkl"
 if os.path.exists(index_path):
     with open(index_path, "rb") as f:
-        index = FAISS.load(f)
+        faiss_index = pickle.load(f)
     print("Loaded FAISS index from faiss_index.pkl")
 else:
-    # Create a new FAISS index if it doesn't exist
-    index = FAISS(embedding_dimension=embedding_model.get_sentence_embedding_dimension())
-    with open(index_path, "wb") as f:
-        FAISS.save(index, f)
-    print("Created new FAISS index and saved to faiss_index.pkl")
+    print("Created new FAISS index")
 
 def preprocess_text(text):
     sentences = sent_tokenize(text)
     return sentences
 
 def upload_files(files):
-    global index
+    global faiss_index
     try:
-        for file_path in files:
-            if file_path.endswith('.pdf'):
-                text = extract_text_from_pdf(file_path)
-            elif file_path.endswith('.docx'):
-                text = extract_text_from_docx(file_path)
+        for file in files:
+            if file.name.endswith('.pdf'):
+                text = extract_text_from_pdf(file.name)
+            elif file.name.endswith('.docx'):
+                text = extract_text_from_docx(file.name)
             else:
                 return {"error": "Unsupported file format"}
 
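A note on the initialization added in this hunk: LangChain's FAISS class is not constructed with an embedding wrapper and a dimension keyword; it is normally created through factory methods such as FAISS.from_texts, and persisted with save_local/load_local rather than raw pickle. A minimal sketch of the load-or-create step under the import assumptions above (the faiss_index directory name is illustrative):

import os

index_dir = "faiss_index"  # illustrative on-disk location

if os.path.exists(index_dir):
    # allow_dangerous_deserialization is required on recent releases
    # because the docstore is pickled inside the saved folder.
    faiss_index = FAISS.load_local(index_dir, embeddings,
                                   allow_dangerous_deserialization=True)
    print(f"Loaded FAISS index from {index_dir}")
else:
    faiss_index = None  # created from the first uploaded document
    print("No index on disk yet; one will be created on first upload")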
 
@@ -71,7 +73,11 @@ def upload_files(files):
 
         # Encode sentences and add to FAISS index
         embeddings = embedding_model.encode(sentences)
-        index.add(embeddings)
+        faiss_index.add_texts(sentences, embeddings)
+
+        # Save the updated index
+        with open(index_path, "wb") as f:
+            pickle.dump(faiss_index, f)
 
         return {"message": "Files processed successfully"}
     except Exception as e:
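Worth noting for this hunk: add_texts embeds the sentences itself through the store's embedding function, and its second positional parameter is metadatas, so passing precomputed vectors there would misbind. A sketch of the indexing step that follows from the load-or-create logic above (index_sentences is a hypothetical helper, not in the commit):

def index_sentences(sentences):
    # Add tokenized sentences to the store, creating it on first use.
    global faiss_index
    if faiss_index is None:
        # The factory method wires the embedding function, raw faiss
        # index, docstore, and id map together.
        faiss_index = FAISS.from_texts(sentences, embeddings)
    else:
        faiss_index.add_texts(sentences)  # embeds internally
    faiss_index.save_local(index_dir)     # persist after every update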
@@ -88,22 +94,18 @@ def process_and_query(state, files, question):
     # Preprocess the question
     question_embedding = embedding_model.encode([question])
 
-    # Use retriever model to retrieve relevant passages
-    with torch.no_grad():
-        retriever_outputs = retriever(**retriever_tokenizer(question, return_tensors="pt"))
-        retriever_hidden_states = retriever_outputs.hidden_states[-1]  # Last hidden state
-
-    # Search the FAISS index for similar passages based on retrieved hidden states
-    distances, retrieved_ids = index.search(retriever_hidden_states.cpu().numpy(), k=5)  # Retrieve top 5 passages
+    # Search the FAISS index for similar passages
+    distances, retrieved_ids = faiss_index.similarity_search_with_score(question_embedding, k=5)  # Retrieve top 5 passages
 
     # Get the retrieved passages from the document text
     retrieved_passages = [state["processed_text"].split("\n")[i] for i in retrieved_ids.flatten()]
 
     # Use generator model to generate response based on question and retrieved passages
-    combined_input = torch.cat([question_embedding, embedding_model.encode(retrieved_passages)], dim=0)
+    combined_input = question + " ".join(retrieved_passages)
+    inputs = generator_tokenizer(combined_input, return_tensors="pt")
     with torch.no_grad():
-        generator_outputs = generator(**generator_tokenizer(combined_input, return_tensors="pt"))
-        generated_text = generator_tokenizer.decode(generator_outputs.sequences.squeeze())
+        generator_outputs = generator.generate(**inputs)
+        generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
 
     # Update conversation history
     state["conversation"].append({"question": question, "answer": generated_text})
@@ -131,3 +133,4 @@ with gr.Blocks() as demo:
 demo.launch()
 
 
+
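End to end, the pieces sketched above can be smoke-tested from a Python shell before the Gradio UI is wired in; a hypothetical session using the index_sentences and answer helpers defined in the sketches:

index_sentences(["FAISS is a library for efficient similarity search.",
                 "Gradio builds quick web demos for ML models."])
print(answer("What is FAISS used for?"))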