import os
import numpy as np
import gradio as gr
from docx import Document
import fitz  # PyMuPDF for PDF text extraction
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import sent_tokenize
import nltk
import faiss

# Ensure NLTK sentence-tokenizer data is available
# (newer NLTK releases look for 'punkt_tab' rather than 'punkt')
for resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)

# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Hugging Face API token; kept as a fail-fast check so the app does not
# start half-configured (e.g. on a Space missing its secret)
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

# Generator model for the RAG pipeline; retrieval is handled by the
# sentence-transformer embeddings and the FAISS index below, so no
# separate seq2seq "retriever" model is needed
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)

# Load or create the FAISS index. The index is serialized with
# faiss.write_index, so it must be loaded with faiss.read_index
# (not pickle)
index_path = "faiss_index.bin"
if os.path.exists(index_path):
    faiss_index = faiss.read_index(index_path)
    print(f"Loaded FAISS index from {index_path}")
else:
    # Create a new IVF index; it must be trained before vectors can be added
    d = embedding_model.get_sentence_embedding_dimension()  # embedding dimension
    nlist = 100  # number of IVF clusters
    quantizer = faiss.IndexFlatL2(d)  # coarse quantizer for IVF
    faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
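# Note: IndexIVFFlat probes only a subset of clusters at query time
# (controlled by faiss_index.nprobe, default 1), trading exact search for
# speed, and training requires at least nlist vectors. For a small corpus,
# an exhaustive faiss.IndexFlatL2(d) needs no training and may be the
# simpler choice; IVF is kept here to match the original design.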

def extract_text_from_pdf(pdf_path):
    text = ""
    try:
        # Context manager ensures the document handle is closed
        with fitz.open(pdf_path) as doc:
            for page in doc:
                text += page.get_text()
    except Exception as e:
        raise RuntimeError(f"Error extracting text from PDF '{pdf_path}': {e}")
    return text

def extract_text_from_docx(docx_path):
    text = ""
    try:
        doc = Document(docx_path)
        text = "\n".join([para.text for para in doc.paragraphs])
    except Exception as e:
        raise RuntimeError(f"Error extracting text from DOCX '{docx_path}': {e}")
    return text

def preprocess_text(text):
    sentences = sent_tokenize(text)
    return sentences

# In-memory store mapping FAISS vector ids back to their source sentences
# (kept per session; a persistent app would save this alongside the index)
document_sentences = []

def upload_files(files):
    global faiss_index
    try:
        for file in files:
            file_path = file.name
            if file_path.lower().endswith('.pdf'):
                text = extract_text_from_pdf(file_path)
            elif file_path.lower().endswith('.docx'):
                text = extract_text_from_docx(file_path)
            else:
                return {"error": f"Unsupported file format: {file_path}"}

            sentences = preprocess_text(text)
            embeddings = np.asarray(embedding_model.encode(sentences), dtype='float32')

            # An IVF index must be trained before vectors can be added; train
            # on the first batch (FAISS needs at least nlist training vectors,
            # so very small uploads would need a flat index instead)
            if not faiss_index.is_trained:
                faiss_index.train(embeddings)

            faiss_index.add(embeddings)  # add the whole batch at once
            document_sentences.extend(sentences)

        # Persist the updated index to disk
        faiss.write_index(faiss_index, index_path)
        return {"message": "Files processed successfully"}

    except Exception as e:
        print(f"Error processing files: {e}")
        return {"error": str(e)}

def process_and_query(question, files=None):
    if files:
        upload_result = upload_files(files)
        if "error" in upload_result:
            return upload_result
    if not question:
        return {"error": "No question provided"}
    if not faiss_index.is_trained or faiss_index.ntotal == 0:
        return {"error": "No documents have been indexed yet"}

    # Retrieve the sentences closest to the question in embedding space
    question_embedding = np.asarray(embedding_model.encode([question]), dtype='float32')
    _, indices = faiss_index.search(question_embedding, 5)
    context = " ".join(document_sentences[i] for i in indices[0]
                       if 0 <= i < len(document_sentences))

    # Generate an answer conditioned on the retrieved context
    inputs = generator_tokenizer(f"question: {question} context: {context}",
                                 return_tensors="pt", truncation=True, max_length=1024)
    outputs = generator.generate(**inputs, max_new_tokens=128)
    return {"answer": generator_tokenizer.decode(outputs[0], skip_special_tokens=True)}
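# Quick sanity check outside the UI (hypothetical question string; assumes
# at least one document was uploaded and indexed in this session):
#   print(process_and_query("What topics do the uploaded documents cover?"))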

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")

    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)

    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)

demo.launch()
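# To run locally (the filename 'app.py' is an assumption; use whatever this
# script is saved as):
#   export HUGGINGFACEHUB_API_TOKEN=<your token>
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.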