import os
import numpy as np
import faiss
import fitz  # PyMuPDF
import nltk
import torch
import gradio as gr
from docx import Document
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import sent_tokenize

# sent_tokenize needs the NLTK "punkt" tokenizer data (newer NLTK releases use "punkt_tab")
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    text = ""
    doc = fitz.open(pdf_path)
    for page in doc:
        text += page.get_text()
    return text
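
# Note: page.get_text() returns PyMuPDF's plain-text rendering, so multi-column
# layouts and tables are flattened into a single text stream.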

# Function to extract text from a Word document
def extract_text_from_docx(docx_path):
    doc = Document(docx_path)
    text = "\n".join([para.text for para in doc.paragraphs])
    return text
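
# Note: Document.paragraphs does not include text inside tables, headers, or
# footers, so that content is not captured here.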

# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
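
# Note: all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; the FAISS
# index below is sized with get_sentence_embedding_dimension() so the two match.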

# Hugging Face API token (optional: only needed for gated or private models;
# facebook/bart-base is public and downloads without one)
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# Generator model for the RAG pipeline. Retrieval itself is handled by the
# sentence-transformer embeddings and the FAISS index below, so no separate
# retriever language model is needed.
generator_model_name = "facebook/bart-base"
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name, token=api_token)
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name, token=api_token)
generator.eval()

# Load or create the FAISS index. The index is driven directly with numpy
# embeddings, so the raw faiss library is used rather than the LangChain wrapper.
# Passage texts are kept in a parallel list so that search results (row ids)
# can be mapped back to the original sentences.
index_path = "faiss_index.bin"
passages_path = "passages.txt"
embedding_dim = embedding_model.get_sentence_embedding_dimension()

if os.path.exists(index_path) and os.path.exists(passages_path):
    index = faiss.read_index(index_path)
    with open(passages_path, "r", encoding="utf-8") as f:
        passages = [line.rstrip("\n") for line in f]
    print(f"Loaded FAISS index with {index.ntotal} passages from {index_path}")
else:
    # Create a new, empty index if nothing has been saved yet
    index = faiss.IndexFlatL2(embedding_dim)
    passages = []
    print("Created a new, empty FAISS index")

def preprocess_text(text):
    # Collapse whitespace (PDF extraction inserts hard line breaks), then split into sentences
    text = " ".join(text.split())
    sentences = [s for s in sent_tokenize(text) if s.strip()]
    return sentences

def upload_files(files):
    """Extract text from the uploaded files, embed it, and add it to the FAISS index."""
    global index, passages
    try:
        if not files:
            return "No files uploaded"

        for file_obj in files:
            # gr.File may return plain paths or tempfile-like objects depending on the Gradio version
            file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
            if file_path.lower().endswith('.pdf'):
                text = extract_text_from_pdf(file_path)
            elif file_path.lower().endswith('.docx'):
                text = extract_text_from_docx(file_path)
            else:
                return f"Unsupported file format: {os.path.basename(file_path)}"

            # Preprocess text into sentences
            sentences = preprocess_text(text)
            if not sentences:
                continue

            # Encode sentences and add them to the FAISS index, keeping the texts in step
            embeddings = embedding_model.encode(sentences)
            index.add(np.asarray(embeddings, dtype="float32"))
            passages.extend(sentences)

        # Persist the index and the passage texts so they survive a restart
        faiss.write_index(index, index_path)
        with open(passages_path, "w", encoding="utf-8") as f:
            f.write("\n".join(p.replace("\n", " ") for p in passages))

        return f"Files processed successfully ({index.ntotal} passages indexed)"
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"Error processing files: {e}"

def process_and_query(state, question):
    """Retrieve the passages most relevant to the question and generate an answer."""
    state = state or {"conversation": []}

    if not question or not question.strip():
        return "No question provided", state
    if index.ntotal == 0:
        return "No documents indexed yet; please upload files first", state

    # Embed the question and retrieve the top-k most similar passages from the FAISS index
    question_embedding = embedding_model.encode([question])
    k = min(5, index.ntotal)
    distances, retrieved_ids = index.search(np.asarray(question_embedding, dtype="float32"), k)
    retrieved_passages = [passages[i] for i in retrieved_ids[0] if 0 <= i < len(passages)]

    # Build a prompt from the question and the retrieved context, then generate an answer
    context = " ".join(retrieved_passages)
    prompt = f"question: {question} context: {context}"
    inputs = generator_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        output_ids = generator.generate(**inputs, max_new_tokens=128)
    generated_text = generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Update the conversation history kept in the Gradio session state
    state["conversation"].append({"question": question, "answer": generated_text})

    return generated_text, state

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")
    state = gr.State({"conversation": []})  # per-session conversation history

    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)

    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)

demo.launch()
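
# A plausible dependency set for this script, based on the imports above
# (package names are the usual PyPI distributions for these modules):
#   pip install pymupdf python-docx sentence-transformers faiss-cpu transformers nltk torch gradio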