NaimaAqeel commited on
Commit
3a0b46d
·
verified ·
1 Parent(s): 85a9507

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -32
app.py CHANGED
@@ -10,8 +10,6 @@ from typing import List
10
  from langchain_community.llms import HuggingFaceEndpoint
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
- from nltk.tokenize import sent_tokenize # Import for sentence segmentation
14
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
15
 
16
  # Function to extract text from a PDF file
17
  def extract_text_from_pdf(pdf_path):
@@ -25,9 +23,8 @@ def extract_text_from_pdf(pdf_path):
25
  print(f"Error extracting text from PDF: {e}")
26
  return text
27
 
28
- # Function to extract text from a Word document
29
  def extract_text_from_docx(docx_path):
30
- """Extracts text from a Word document."""
31
  text = ""
32
  try:
33
  doc = Document(docx_path)
@@ -36,32 +33,29 @@ def extract_text_from_docx(docx_path):
36
  print(f"Error extracting text from DOCX: {e}")
37
  return text
38
 
39
-
40
- # Initialize the embedding model
41
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
42
 
43
-
44
- # Hugging Face API token
45
  api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
46
  if not api_token:
47
  raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
48
 
 
 
 
 
 
49
 
50
- # Define RAG models
51
- generator_model_name = "facebook/bart-base"
52
- retriever_model_name = "facebook/bart-base"
53
-
54
- generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
55
- generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
56
-
57
- retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
58
- retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
59
-
60
 
61
  # Load or create FAISS index
62
  index_path = "faiss_index.pkl"
63
  document_texts_path = "document_texts.pkl"
 
64
  document_texts = []
 
65
  if os.path.exists(index_path) and os.path.exists(document_texts_path):
66
  try:
67
  with open(index_path, "rb") as f:
@@ -79,16 +73,11 @@ else:
79
  pickle.dump(index, f)
80
  print("Created new FAISS index and saved to faiss_index.pkl")
81
 
82
-
83
- def preprocess_text(text):
84
- sentences = sent_tokenize(text)
85
- return sentences
86
-
87
-
88
  def upload_files(files):
89
  global index, document_texts
90
  try:
91
- for file_path in files:
 
92
  if file_path.endswith('.pdf'):
93
  text = extract_text_from_pdf(file_path)
94
  elif file_path.endswith('.docx'):
@@ -96,18 +85,52 @@ def upload_files(files):
96
  else:
97
  return "Unsupported file format"
98
 
99
- # Preprocess text (call the new function)
100
- sentences = preprocess_text(text)
101
-
102
- # Encode sentences and add to FAISS index
103
  embeddings = embedding_model.encode(sentences)
104
  index.add(np.array(embeddings))
105
-
106
- # Save the updated index and documents
107
-
 
 
 
 
 
 
 
108
  return "Files processed successfully"
109
  except Exception as e:
110
  print(f"Error processing files: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
 
 
10
  from langchain_community.llms import HuggingFaceEndpoint
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
13
 
14
  # Function to extract text from a PDF file
15
  def extract_text_from_pdf(pdf_path):
 
23
  print(f"Error extracting text from PDF: {e}")
24
  return text
25
 
26
+ # Function to extract text from a Word document
27
  def extract_text_from_docx(docx_path):
 
28
  text = ""
29
  try:
30
  doc = Document(docx_path)
 
33
  print(f"Error extracting text from DOCX: {e}")
34
  return text
35
 
36
+ # Initialize the embedding model
 
37
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
38
 
39
+ # Hugging Face API token
 
40
  api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
41
  if not api_token:
42
  raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
43
 
44
+ # Initialize the HuggingFace LLM
45
+ llm = HuggingFaceEndpoint(
46
+ endpoint_url="https://api-inference.huggingface.co/models/gpt2",
47
+ model_kwargs={"api_key": api_token}
48
+ )
49
 
50
+ # Initialize the HuggingFace embeddings
51
+ embedding = HuggingFaceEmbeddings()
 
 
 
 
 
 
 
 
52
 
53
  # Load or create FAISS index
54
  index_path = "faiss_index.pkl"
55
  document_texts_path = "document_texts.pkl"
56
+
57
  document_texts = []
58
+
59
  if os.path.exists(index_path) and os.path.exists(document_texts_path):
60
  try:
61
  with open(index_path, "rb") as f:
 
73
  pickle.dump(index, f)
74
  print("Created new FAISS index and saved to faiss_index.pkl")
75
 
 
 
 
 
 
 
76
  def upload_files(files):
77
  global index, document_texts
78
  try:
79
+ for file in files:
80
+ file_path = file.name # Get the file path from the NamedString object
81
  if file_path.endswith('.pdf'):
82
  text = extract_text_from_pdf(file_path)
83
  elif file_path.endswith('.docx'):
 
85
  else:
86
  return "Unsupported file format"
87
 
88
+ # Process the text and update FAISS index
89
+ sentences = text.split("\n")
 
 
90
  embeddings = embedding_model.encode(sentences)
91
  index.add(np.array(embeddings))
92
+ document_texts.append(text)
93
+
94
+ # Save the updated index and documents
95
+ with open(index_path, "wb") as f:
96
+ pickle.dump(index, f)
97
+ print("Saved updated FAISS index to faiss_index.pkl")
98
+ with open(document_texts_path, "wb") as f:
99
+ pickle.dump(document_texts, f)
100
+ print("Saved updated document texts to document_texts.pkl")
101
+
102
  return "Files processed successfully"
103
  except Exception as e:
104
  print(f"Error processing files: {e}")
105
+ return f"Error processing files: {e}"
106
+
107
+ def query_text(text):
108
+ try:
109
+ # Encode the query text
110
+ query_embedding = embedding_model.encode([text])
111
+
112
+ # Search the FAISS index
113
+ D, I = index.search(np.array(query_embedding), k=5)
114
+
115
+ top_documents = []
116
+ for idx in I[0]:
117
+ if idx != -1 and idx < len(document_texts): # Ensure that a valid index is found
118
+ top_documents.append(document_texts[idx])
119
+ else:
120
+ print(f"Invalid index found: {idx}")
121
+ return top_documents
122
+ except Exception as e:
123
+ print(f"Error querying text: {e}")
124
+ return f"Error querying text: {e}"
125
+
126
+ # Create Gradio interface
127
+ with gr.Blocks() as demo:
128
+ gr.Markdown("## Document Upload and Query System")
129
+
130
+ with gr.Tab("Upload Files"):
131
+ upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
132
+ upload_button = gr.Button("Upload")
133
+ upload_output = gr
134
 
135
 
136