File size: 5,613 Bytes
1649416
834c71a
 
98c11b9
834c71a
80e4cb4
3ee65d3
 
90bf4dc
80e4cb4
 
834c71a
377f3f1
56ec544
3ee65d3
 
834c71a
 
 
e4b3db1
834c71a
 
e4b3db1
834c71a
 
 
 
 
 
 
e4b3db1
834c71a
 
e4b3db1
834c71a
 
 
 
2b77a1d
3ee65d3
80e4cb4
 
3ee65d3
80e4cb4
 
 
 
3ee65d3
80e4cb4
3ee65d3
80e4cb4
 
 
 
 
 
 
 
 
834c71a
80e4cb4
834c71a
 
 
80e4cb4
834c71a
d382509
90bf4dc
be68f20
 
90bf4dc
7adb197
834c71a
90bf4dc
7adb197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80e4cb4
124e62a
834c71a
 
be68f20
7adb197
90bf4dc
7adb197
834c71a
eda6735
7adb197
834c71a
 
 
be68f20
 
834c71a
 
 
ec0cc7d
834c71a
e510bfe
 
 
 
 
 
834c71a
 
e510bfe
 
834c71a
 
ec0cc7d
 
 
 
 
6927a5e
ec0cc7d
6927a5e
c4f7f00
834c71a
9bf056e
4a20efe
 
 
90bf4dc
9bf056e
 
 
 
 
 
 
 
 
 
 
 
 
d7e53b2
caf7f08
4953693
9bf056e
 
caf7f08
4953693
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import io
import os
import pickle

import faiss
import gradio as gr
import numpy as np
import PyPDF2
import torch
from docx import Document
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Download the NLTK sentence-tokenizer data used by sent_tokenize.
# NLTK 3.8.2+ makes sent_tokenize load the "punkt_tab" resource instead of
# the pickled "punkt" model, so fetch both to work across NLTK versions.
import nltk
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_data):
    """Return the text of every page of a PDF, pages separated by newlines.

    Args:
        pdf_data: Raw bytes of the PDF file.

    Returns:
        The extracted text. Extraction is best-effort: if parsing fails the
        error is printed and whatever text was collected so far is returned
        (possibly "").
    """
    pages = []
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
        for page in pdf_reader.pages:
            # Guard against a None/empty result (e.g. image-only pages) so one
            # page cannot abort extraction of the rest.
            pages.append(page.extract_text() or "")
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
    # Separate pages so the last word of one page is not glued to the first
    # word of the next (which would also confuse sentence splitting).
    return "\n".join(pages)

# Function to extract text from a Word document
def extract_text_from_docx(docx_data):
    """Return the paragraphs of a Word (.docx) document joined by newlines.

    Args:
        docx_data: Raw bytes of the .docx file.

    Returns:
        The paragraph texts separated by newline characters, or "" if the
        document cannot be parsed (the error is printed, not raised).
    """
    try:
        document = Document(io.BytesIO(docx_data))
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    except Exception as err:
        print(f"Error extracting text from DOCX: {err}")
        return ""

# Sentence encoder (all-MiniLM-L6-v2) used by upload_files to embed document
# sentences and by process_and_query to embed the question.
embedding_model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

# Require a Hugging Face Hub token at start-up and fail fast when it is absent
# or empty.
# NOTE(review): nothing else in this file reads api_token — confirm it is
# still needed.
api_token = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
if api_token is None or api_token == "":
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

# Seq2seq models from the Hugging Face Hub. process_and_query uses the
# generator pair to write answers. The "retriever" pair is not read anywhere
# in this file: retrieval is done with FAISS over sentence embeddings.
generator_model_name = "facebook/bart-base"
retriever_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
if retriever_model_name == generator_model_name:
    # Same checkpoint: reuse the already-loaded weights instead of keeping a
    # second full copy of the model in memory.
    retriever = generator
    retriever_tokenizer = generator_tokenizer
else:
    retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
    retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)

# LangChain embedding wrapper, passed to the LangChain FAISS store as its
# embedding function.
# NOTE(review): this appears to load the same MiniLM checkpoint as
# embedding_model a second time — confirm whether one instance can be shared.
hf_embeddings = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
)

# Load the persisted LangChain FAISS vector store, or start an empty one.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
# file. This is only acceptable because upload_files is the only writer of
# index_path in this file. Never point it at a file from an untrusted source.
index_path = "faiss_index.pkl"
if os.path.exists(index_path):
    with open(index_path, "rb") as f:
        faiss_index = pickle.load(f)
        print("Loaded FAISS index from faiss_index.pkl")
else:
    # LangChain's FAISS store cannot be built from an embedding function
    # alone. It also needs a raw faiss index, a docstore, and the
    # position -> docstore-id mapping. Size the flat L2 index by embedding a
    # probe string, so it always matches hf_embeddings' output dimension.
    _embedding_dim = len(hf_embeddings.embed_query("dimension probe"))
    faiss_index = FAISS(
        embedding_function=hf_embeddings,
        index=faiss.IndexFlatL2(_embedding_dim),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )

def preprocess_text(text):
    """Split ``text`` into sentences.

    Args:
        text: Plain text extracted from an uploaded document.

    Returns:
        The list of sentences produced by NLTK's ``sent_tokenize``.
    """
    return sent_tokenize(text)

def upload_files(files):
    """Extract text from uploaded PDF/DOCX files and add its sentences to the index.

    Every file's extension is checked before anything is indexed. An
    unsupported file therefore rejects the whole batch, instead of leaving
    the in-memory index partially updated and out of sync with the saved copy.

    Args:
        files: Value of the ``gr.File(file_count="multiple")`` component, a
            list of uploads. Each item may be a path string or an object whose
            ``.name`` is the uploaded file's path on disk.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` describing why
        the upload failed.
    """
    if not files:
        return {"error": "No files provided"}

    # Resolve every upload to a path and validate extensions up front
    # (case-insensitively, so "REPORT.PDF" is accepted).
    paths = []
    for file in files:
        path = str(getattr(file, "name", file))
        if not path.lower().endswith(('.pdf', '.docx')):
            return {"error": "Unsupported file format"}
        paths.append(path)

    try:
        for path in paths:
            # Read from the path rather than calling file.read(): uploads that
            # arrive as plain path strings have no read() method.
            with open(path, "rb") as fh:
                data = fh.read()
            if path.lower().endswith('.pdf'):
                text = extract_text_from_pdf(data)
            else:
                text = extract_text_from_docx(data)

            sentences = preprocess_text(text)
            if not sentences:
                # Nothing extractable (e.g. an image-only PDF). FAISS cannot
                # add an empty batch of vectors.
                continue

            # Store each sentence together with its vector, so similarity
            # search can return the sentence text at query time.
            embeddings = embedding_model.encode(sentences)
            faiss_index.add_embeddings(
                [(sentence, vector.tolist()) for sentence, vector in zip(sentences, embeddings)]
            )

        # Persist the updated store; it is read back with pickle at start-up.
        with open(index_path, "wb") as f:
            pickle.dump(faiss_index, f)

        return {"message": "Files processed successfully"}
    except Exception as e:
        print(f"Error processing files: {e}")
        return {"error": str(e)}  # Provide informative error message


def process_and_query(state, question, k=5, max_new_tokens=256):
    """Answer ``question`` from the indexed documents and record the exchange.

    The question is embedded and the ``k`` nearest stored sentences are
    fetched from the FAISS vector store. The generator model is then prompted
    with those sentences as context.

    Args:
        state: Conversation history, a list of ``{"question", "answer"}``
            dicts appended to in place. ``None`` starts a fresh list.
        question: The user's question. Empty or whitespace-only input is
            rejected.
        k: Number of sentences to retrieve as context.
        max_new_tokens: Upper bound on the generated answer length. Without
            it ``generate`` falls back to the model's default generation
            length, which can cut answers off after a few words.

    Returns:
        ``{"message": answer, "conversation": state}`` on success, otherwise
        ``{"error": "No question provided"}``.
    """
    if not question or not question.strip():
        return {"error": "No question provided"}
    if state is None:
        state = []

    # Embed the question with the same encoder used when indexing sentences.
    question_embedding = embedding_model.encode([question])[0]

    # faiss_index is a LangChain FAISS store. Query it through its vector
    # search API and read the sentence text back from each returned Document.
    docs = faiss_index.similarity_search_by_vector(question_embedding.tolist(), k=k)
    retrieved_passages = [doc.page_content for doc in docs]

    # Use generator model to generate response based on question and retrieved passages
    prompt_template = """
        Answer the question as detailed as possible from the provided context,
        make sure to provide all the details, if the answer is not in
        provided context just say, "answer is not available in the context",
        don't provide the wrong answer

        Context:\n{context}\n
        Question:\n{question}\n
        Answer:
        """
    combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
    # Truncate to the model's maximum input length. Long retrieved context
    # would otherwise overflow the position embeddings and crash generate().
    inputs = generator_tokenizer(combined_input, return_tensors="pt", truncation=True)
    with torch.no_grad():
        generator_outputs = generator.generate(**inputs, max_new_tokens=max_new_tokens)
    generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)

    # Update conversation history
    state.append({"question": question, "answer": generated_text})

    return {"message": generated_text, "conversation": state}
    
# Initial (empty) conversation history. gr.State deep-copies its initial
# value, so each browser session starts with its own list.
state = []

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")

    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)

    with gr.Tab("Query"):
        # process_and_query takes (state, question), so the conversation
        # history must be wired in as an input alongside the query text.
        conversation_state = gr.State(value=state)
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()

        def _run_query(history, question):
            """Answer one query and return (result, updated history) for Gradio."""
            result = process_and_query(history, question)
            # Write the history back to gr.State explicitly, rather than
            # relying on process_and_query's in-place append being persisted.
            return result, result.get("conversation", history)

        query_button.click(
            fn=_run_query,
            inputs=[conversation_state, query],
            outputs=[query_output, conversation_state],
        )

demo.launch()