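"""Document Upload and Query System.

A Gradio app: upload PDF or DOCX files, split them into sentences, embed the
sentences with all-MiniLM-L6-v2, store them in a FAISS index, and answer
questions by retrieving the closest sentences and feeding them as context to
a BART seq2seq model.
"""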
import os
import gradio as gr
from docx import Document
import fitz  # PyMuPDF for PDF text extraction
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import sent_tokenize
import torch
import pickle
import nltk
import faiss
import numpy as np

# Ensure NLTK resources are downloaded (newer NLTK releases also need 'punkt_tab')
for resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)

# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Hugging Face API token (optional: the models below are public and loaded locally)
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# Define the generator model for the RAG pipeline. Retrieval itself is handled
# by the FAISS index over sentence embeddings, so no separate retriever model
# needs to be loaded.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)

# Load or create the FAISS index, together with the sentence store that maps
# index positions back to text. Persisting both keeps them consistent across
# restarts; loading the index alone would leave the sentence list empty and
# make retrieved indices unusable.
index_path = "faiss_index.index"
sentences_path = "faiss_sentences.pkl"
if os.path.exists(index_path) and os.path.exists(sentences_path):
    faiss_index = faiss.read_index(index_path)
    with open(sentences_path, "rb") as f:
        stored_sentences = pickle.load(f)
    print(f"Loaded FAISS index from {index_path}")
else:
    # Create a new FAISS index
    d = embedding_model.get_sentence_embedding_dimension()  # Dimension of the embeddings
    faiss_index = faiss.IndexFlatL2(d)  # IndexFlatL2 for simplicity (exact L2 search)
    stored_sentences = []

state = {
    "conversation": [],
    "sentences": stored_sentences
}

def extract_text_from_pdf(pdf_path):
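    """Extract all text from a PDF file using PyMuPDF (fitz)."""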
    text = ""
    try:
        doc = fitz.open(pdf_path)
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            text += page.get_text()
    except Exception as e:
        raise RuntimeError(f"Error extracting text from PDF '{pdf_path}': {e}")
    return text

def extract_text_from_docx(docx_path):
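    """Extract all text from a DOCX file using python-docx."""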
    text = ""
    try:
        doc = Document(docx_path)
        text = "\n".join([para.text for para in doc.paragraphs])
    except Exception as e:
        raise RuntimeError(f"Error extracting text from DOCX '{docx_path}': {e}")
    return text

def preprocess_text(text):
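    """Split raw text into sentences with NLTK's sentence tokenizer."""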
    sentences = sent_tokenize(text)
    return sentences

def upload_files(files):
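    """Extract text from uploaded PDF/DOCX files, embed each sentence, and add
    the embeddings to the FAISS index. Wired to the Upload button below."""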
    global state, faiss_index
    try:
        for file in files:
            try:
                if isinstance(file, str):
                    file_path = file
                else:
                    file_path = file.name

                if file_path.endswith('.pdf'):
                    text = extract_text_from_pdf(file_path)
                elif file_path.endswith('.docx'):
                    text = extract_text_from_docx(file_path)
                else:
                    return {"error": f"Unsupported file format: {file_path}"}

                sentences = preprocess_text(text)
                embeddings = embedding_model.encode(sentences)
                
                faiss_index.add(np.array(embeddings).astype(np.float32))  # Add embeddings
                state["sentences"].extend(sentences)

            except Exception as e:
                print(f"Error processing file '{file}': {e}")
                return {"error": str(e)}

        # Persist the updated index together with the matching sentence store
        faiss.write_index(faiss_index, index_path)
        with open(sentences_path, "wb") as f:
            pickle.dump(state["sentences"], f)

        return {"message": "Files processed successfully"}
    
    except Exception as e:
        print(f"General error processing files: {e}")
        return {"error": str(e)}

def process_and_query(question):
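    """Embed the question, retrieve the closest indexed sentences from FAISS,
    and generate an answer with the BART model. Wired to the Search button."""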
    global state, faiss_index
    if not question:
        return {"error": "No question provided"}

    try:
        if faiss_index.ntotal == 0:
            return {"error": "No documents indexed yet. Please upload files first."}

        question_embedding = embedding_model.encode([question])

        # Perform FAISS search for the 5 nearest sentences
        D, I = faiss_index.search(np.array(question_embedding).astype(np.float32), k=5)
        retrieved_results = [state["sentences"][i] for i in I[0] if 0 <= i < len(state["sentences"])]  # Ensure valid indices

        # Generate response based on retrieved results
        context = " ".join(retrieved_results)
        
        # Enhanced prompt template. bart-base truncates input at 512 tokens, so
        # the context and question come first; the long suggestion list at the
        # end is what gets cut off when the prompt is too long.
        prompt_template = """
        Answer the question as detailed as possible from the provided context.
        Make sure to provide all the details. If the answer is not in the
        provided context, just say "answer is not available in the context";
        don't provide a wrong answer.

        Context:\n{context}

        Question:\n{question}

        Answer:
        --------------------------------------------------
        Prompt Suggestions:
        1. Summarize the primary theme of the context.
        2. Elaborate on the crucial concepts highlighted in the context.
        3. Pinpoint any supporting details or examples pertinent to the question.
        4. Examine any recurring themes or patterns relevant to the question within the context.
        5. Contrast differing viewpoints or elements mentioned in the context.
        6. Explore the potential implications or outcomes of the information provided.
        7. Assess the trustworthiness and validity of the information given.
        8. Propose recommendations or advice based on the presented information.
        9. Forecast likely future events or results stemming from the context.
        10. Expand on the context or background information pertinent to the question.
        11. Define any specialized terms or technical language used within the context.
        12. Analyze any visual representations like charts or graphs in the context.
        13. Highlight any restrictions or important considerations when responding to the question.
        14. Examine any presuppositions or biases evident within the context.
        15. Present alternate interpretations or viewpoints regarding the information provided.
        16. Reflect on any moral or ethical issues raised by the context.
        17. Investigate any cause-and-effect relationships identified in the context.
        18. Uncover any questions or areas requiring further exploration.
        19. Resolve any vague or conflicting information in the context.
        20. Cite case studies or examples that demonstrate the concepts discussed in the context.
        --------------------------------------------------
        """
        
        combined_input = prompt_template.format(context=context, question=question)
        inputs = generator_tokenizer(combined_input, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            # Without max_new_tokens, generate() falls back to a very short default length
            generator_outputs = generator.generate(**inputs, max_new_tokens=256)
            generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)

        # Update conversation history
        state["conversation"].append({"question": question, "answer": generated_text})

        return {"message": generated_text, "conversation": state["conversation"]}

    except Exception as e:
        print(f"Error processing query: {e}")
        return {"error": str(e)}

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")

    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=[upload], outputs=upload_output)

    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)

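# Launch the Gradio app (pass share=True to demo.launch() for a public link)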
demo.launch()