import os
import pickle

import fitz  # PyMuPDF, for PDF text extraction
import faiss
import numpy as np
import gradio as gr
import nltk
from docx import Document
from nltk.tokenize import sent_tokenize  # sentence segmentation
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

nltk.download("punkt", quiet=True)  # sent_tokenize needs the Punkt data (newer NLTK may also want "punkt_tab")

# Extract all page text from a PDF file.
# (The original body was elided as "same as before"; this is a minimal PyMuPDF sketch.)
def extract_text_from_pdf(pdf_path):
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)

# Extract paragraph text from a Word document.
# (The original body was elided; this is a minimal python-docx sketch.)
def extract_text_from_docx(docx_path):
    doc = Document(docx_path)
    return "\n".join(paragraph.text for paragraph in doc.paragraphs)

# Initialize the sentence-embedding model used for both indexing and querying.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
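# Note: all-MiniLM-L6-v2 produces 384-dimensional vectors; the FAISS index
# below must be created with the same dimensionality or index.add() will fail.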


# Hugging Face API token. The check is kept from the original code; the token
# is only needed if you later call hosted endpoints or download gated models.
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")


# RAG generator model (replace with your chosen model). Dense retrieval is
# already handled by the sentence-transformer + FAISS index above, so no
# second "retriever" seq2seq model is needed; the original loaded one, but
# its forward pass was never a valid retrieval step.
generator_model_name = "facebook/bart-base"

generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)


# Load a previously saved FAISS index and document texts, or start fresh.
# (The original loading logic was elided; this is a minimal sketch. The .pkl
# filenames are kept from the original, but faiss.read_index/write_index is
# used for the index itself, since FAISS indexes are not reliably picklable.)
index_path = "faiss_index.pkl"
document_texts_path = "document_texts.pkl"
document_texts = []
if os.path.exists(index_path) and os.path.exists(document_texts_path):
    index = faiss.read_index(index_path)
    with open(document_texts_path, "rb") as f:
        document_texts = pickle.load(f)
else:
    # Flat L2 index sized to the embedding model's output dimension (384).
    index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
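# Sanity check: index.ntotal is the number of stored vectors and should
# always equal len(document_texts), since upload_files() appends to both.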


def preprocess_text(text):
    # Split raw text into cleaned sentences (minimal sketch; the original
    # preprocessing body was elided).
    return [s.strip() for s in sent_tokenize(text) if s.strip()]
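# e.g. preprocess_text("First sentence. Second one.") returns
#      ["First sentence.", "Second one."]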


def upload_files(files):
    global index, document_texts
    try:
        for file_path in files:
            # Route each file to the matching extractor
            # (minimal sketch; the original dispatch logic was elided).
            if file_path.lower().endswith(".pdf"):
                text = extract_text_from_pdf(file_path)
            elif file_path.lower().endswith(".docx"):
                text = extract_text_from_docx(file_path)
            else:
                continue  # skip unsupported file types

            # Split into sentences, embed, and add to the FAISS index.
            sentences = preprocess_text(text)
            embeddings = embedding_model.encode(sentences)
            index.add(np.array(embeddings, dtype="float32"))
            # Keep the raw sentences aligned with their index rows so that
            # search results can be mapped back to text.
            document_texts.extend(sentences)

        # Persist the updated index and texts (filenames kept from the original).
        faiss.write_index(index, index_path)
        with open(document_texts_path, "wb") as f:
            pickle.dump(document_texts, f)
        return "Files processed successfully"
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"Error processing files: {e}"


def query_text(text):
    try:
        # Embed the query as a single vector and retrieve the k most
        # similar sentences from the FAISS index.
        query_embedding = embedding_model.encode([text])
        D, I = index.search(np.array(query_embedding, dtype="float32"), k=5)
        retrieved_docs = [document_texts[idx] for idx in I[0] if idx != -1]

        # Retrieval-augmented generation: feed the retrieved sentences to the
        # generator as context and decode a proper generate() call. (The
        # original ran a second seq2seq "retriever" over the docs and decoded
        # its raw logits, which is not a valid generation path.)
        prompt = "Context: " + " ".join(retrieved_docs) + "\n\nQuestion: " + text
        inputs = generator_tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=1024
        )
        output_ids = generator.generate(**inputs, max_new_tokens=128)
        response = generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return response
    except Exception as e:
        print(f"Error querying text: {e}")
        return f"Error querying text: {e}"


# Gradio interface (minimal sketch; the original layout was elided and the
# file is truncated mid-line here).
with gr.Blocks() as demo:
    file_input = gr.File(file_count="multiple", label="Upload PDF / DOCX files")
    upload_button = gr.Button("Process files")
    upload_status = gr.Textbox(label="Status")
    query_input = gr.Textbox(label="Ask a question about the documents")
    query_button = gr.Button("Query")
    answer_output = gr.Textbox(label="Answer")

    upload_button.click(fn=upload_files, inputs=file_input, outputs=upload_status)
    query_button.click(fn=query_text, inputs=query_input, outputs=answer_output)
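
if __name__ == "__main__":
    demo.launch()  # assumed entry point; the original file ends before this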