File size: 5,168 Bytes
1649416
03bc240
 
 
90bf4dc
d1c01a2
cb866dd
be68f20
c4f7f00
 
377f3f1
2b77a1d
 
03bc240
 
44b52da
c4f7f00
90bf4dc
c4f7f00
b28c6a7
 
 
 
 
 
 
c4f7f00
d382509
c4f7f00
90bf4dc
b28c6a7
 
 
 
 
 
c4f7f00
d382509
03bc240
90bf4dc
2c02a9e
03bc240
be13366
377f3f1
03bc240
90bf4dc
be68f20
 
377f3f1
be68f20
90bf4dc
03bc240
 
90bf4dc
 
be68f20
 
90bf4dc
c4f7f00
90bf4dc
377f3f1
03bc240
 
 
 
 
 
 
90bf4dc
03bc240
90bf4dc
c4f7f00
be68f20
90bf4dc
be68f20
90bf4dc
03bc240
 
377f3f1
 
 
 
be68f20
 
90bf4dc
 
8c06cc2
90bf4dc
be68f20
 
c4f7f00
be68f20
c4f7f00
eda6735
03bc240
be68f20
 
 
377f3f1
03bc240
 
 
 
 
 
 
be68f20
 
377f3f1
 
be68f20
377f3f1
 
be68f20
 
 
 
c4f7f00
 
 
90bf4dc
 
 
 
c4f7f00
90bf4dc
 
 
 
 
c4f7f00
90bf4dc
 
 
 
03bc240
 
 
eda6735
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import functools
import os
import pickle

import fitz  # PyMuPDF for PDF text extraction
import gradio as gr
import nltk
import torch
from docx import Document  # python-docx for DOCX text extraction
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Download the sentence-tokenizer data used by nltk.sent_tokenize if it is not
# already present.  NLTK 3.8.2 and later load the pickle-free "punkt_tab"
# resource instead of the older "punkt" model, so fetch both to work across
# NLTK versions.  quiet=True suppresses the downloader's progress output.
for _punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(_punkt_resource, quiet=True)

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of the PDF at ``pdf_path``.

    Page texts are concatenated in page order with no separator added.
    Errors are printed, not raised. On failure, the text of the pages read
    before the error is returned, which may be ``""``.
    """
    page_texts = []
    try:
        # The context manager closes the document (and its file handle) even
        # when opening succeeds but a page fails to parse.
        with fitz.open(pdf_path) as doc:
            for page in doc:
                page_texts.append(page.get_text())
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
    return "".join(page_texts)

# Function to extract text from a Word document
def extract_text_from_docx(docx_path):
    """Return the paragraphs of the Word document at ``docx_path``, one per line.

    Any error while opening or reading the document is printed, and ``""``
    is returned instead of raising.
    """
    try:
        document = Document(docx_path)
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    except Exception as err:
        print(f"Error extracting text from DOCX: {err}")
        return ""

# Sentence-embedding model used to encode the sentences that get indexed.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# The same model wrapped for LangChain; the FAISS store uses it to embed
# search queries.
hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

# On-disk location of the pickled FAISS vector store.
index_path = "faiss_index.pkl"


def _create_empty_faiss_index():
    """Return an empty LangChain FAISS store sized for ``embedding_model``.

    LangChain's ``FAISS`` cannot be built from an embedding function alone.
    It also needs a raw faiss index, a docstore and an index-to-docstore id
    mapping.
    """
    from langchain_community.docstore.in_memory import InMemoryDocstore
    from langchain_community.vectorstores.faiss import dependable_faiss_import

    faiss = dependable_faiss_import()
    dimension = embedding_model.get_sentence_embedding_dimension()
    return FAISS(
        embedding_function=hf_embeddings,
        index=faiss.IndexFlatL2(dimension),
        docstore=InMemoryDocstore({}),
        index_to_docstore_id={},
    )


if os.path.exists(index_path):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load an index this application wrote itself; never accept this
    # file from an untrusted source.
    with open(index_path, "rb") as f:
        faiss_index = pickle.load(f)
    print("Loaded FAISS index from faiss_index.pkl")
else:
    faiss_index = _create_empty_faiss_index()

def preprocess_text(text):
    """Split ``text`` into a list of sentence strings with NLTK's ``sent_tokenize``."""
    return sent_tokenize(text)

def upload_files(files):
    """Extract text from uploaded PDF/DOCX files and add its sentences to the index.

    Args:
        files: Iterable of uploads. Each item is either a file-path string or
            an object exposing its path as ``.name``; older Gradio versions
            pass temp-file wrappers. Extensions are matched case-insensitively.

    Returns:
        ``{"message": ...}`` on success, otherwise ``{"error": ...}``. Every
        file is validated and read before the index is modified, so an
        unsupported or invalid entry leaves the index untouched.
    """
    if not files:
        return {"error": "No files provided"}
    try:
        # Pass 1: validate every upload and extract its text.
        texts = []
        for file in files:
            path = file if isinstance(file, str) else getattr(file, "name", None)
            if not isinstance(path, str):
                return {"error": "Invalid file format: expected a string"}
            lowered = path.lower()
            if lowered.endswith('.pdf'):
                texts.append(extract_text_from_pdf(path))
            elif lowered.endswith('.docx'):
                texts.append(extract_text_from_docx(path))
            else:
                return {"error": "Unsupported file format"}

        # Pass 2: embed the sentences and add them to the vector store.
        for text in texts:
            sentences = preprocess_text(text)
            if not sentences:
                # Nothing was extracted (empty or unreadable file); skip it
                # rather than hand an empty batch to add_embeddings.
                continue
            embeddings = embedding_model.encode(sentences)
            # LangChain's FAISS store has no add_sentence(). add_embeddings
            # stores (text, vector) pairs using the precomputed vectors.
            faiss_index.add_embeddings(list(zip(sentences, embeddings.tolist())))

        # Persist the updated index so it is reloaded on the next start.
        with open(index_path, "wb") as f:
            pickle.dump(faiss_index, f)

        return {"message": "Files processed successfully"}
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        print(f"Error processing files: {e}")
        return {"error": str(e)}

# Hugging Face checkpoint used to generate answers from retrieved passages.
GENERATOR_MODEL_NAME = "facebook/bart-base"


@functools.lru_cache(maxsize=1)
def _load_generator(model_name):
    """Load the seq2seq model and tokenizer once; return ``(model, tokenizer)``."""
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def _retrieve_passages(question, k=5):
    """Return the text of the ``k`` indexed passages most similar to ``question``."""
    # LangChain's similarity_search returns Document objects, with the text
    # in .page_content (not dicts). The store embeds the question itself.
    return [doc.page_content for doc in faiss_index.similarity_search(question, k=k)]


def _generate_answer(question, passages):
    """Generate an answer to ``question`` conditioned on the retrieved ``passages``."""
    generator, tokenizer = _load_generator(GENERATOR_MODEL_NAME)
    # Space-separate everything so the question's last word is not glued to
    # the first passage.
    combined_input = " ".join([question, *passages])
    # truncation=True keeps long contexts within the model's maximum input length.
    inputs = tokenizer(combined_input, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = generator.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def process_and_query(state, files, question):
    """Optionally index ``files``, then answer ``question`` from the indexed text.

    Args:
        state: Dict holding the conversation history under ``"conversation"``.
            The key is created if missing, and ``None`` is treated as an empty
            dict.
        files: Uploads forwarded to ``upload_files``; falsy to skip indexing.
        question: The user's question; falsy yields an error result.

    Returns:
        ``{"message": answer, "conversation": history}`` on success, the
        ``upload_files`` error dict if indexing failed, or
        ``{"error": "No question provided"}``.
    """
    if files:
        upload_result = upload_files(files)
        if "error" in upload_result:
            return upload_result

    if not question:
        return {"error": "No question provided"}

    answer = _generate_answer(question, _retrieve_passages(question))

    if state is None:
        state = {}
    conversation = state.setdefault("conversation", [])
    conversation.append({"question": question, "answer": answer})
    return {"message": answer, "conversation": conversation}

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")

    # Per-session conversation history, in the shape process_and_query
    # expects: a dict with a "conversation" list.
    conversation_state = gr.State({"conversation": []})

    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)

    def _answer_query(conversation, question):
        """Adapt process_and_query(state, files, question) to the query tab's inputs."""
        # Files are indexed from the upload tab, so none are passed here.
        result = process_and_query(conversation, None, question)
        # Return the state as well so Gradio stores the updated history.
        return result, conversation

    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(
            fn=_answer_query,
            inputs=[conversation_state, query],
            outputs=[query_output, conversation_state],
        )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()