import os
import pickle
import numpy as np
import gradio as gr
import fitz  # PyMuPDF
from docx import Document
from transformers import AutoModel, AutoTokenizer, pipeline
import faiss
import torch

# ===============================
# EMBEDDING MODEL
# ===============================
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
embedding_model = AutoModel.from_pretrained(model_name)
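# all-MiniLM-L6-v2 is a compact sentence-embedding model; its 384-dimensional
# output vectors determine the embedding_dim used for the FAISS index below.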

def get_embeddings(texts):
    if isinstance(texts, str):
        texts = [texts]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Mean-pool over non-padding tokens (the pooling MiniLM sentence models are
    # trained with) rather than taking the [CLS] token alone, then L2-normalize
    # so the inner-product FAISS index below behaves like cosine similarity.
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()
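# Hypothetical sanity check: a single input string comes back as a (1, 384)
# array, matching the index dimension configured below.
#   get_embeddings("hello world").shape  # -> (1, 384)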

# ===============================
# TEXT CHUNKING
# ===============================
def chunk_text(text, chunk_size=500, overlap=50):
    # Character-based sliding window; consecutive chunks share `overlap`
    # characters so sentences split at a boundary remain retrievable.
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        start += chunk_size - overlap
    return chunks
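# Worked example with the defaults (chunk_size=500, overlap=50):
#   chunk_text("a" * 1200) -> chunks covering [0:500], [450:950], [900:1200],
# i.e. each cut point is shared by two chunks.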

# ===============================
# FAISS INDEX SETUP
# ===============================
index_path = "faiss_index.bin"
document_texts_path = "document_texts.pkl"
document_texts = []
embedding_dim = 384

# Reload a previously saved index if present. The FAISS index is persisted via
# faiss.write_index/read_index, the serialization the library supports
# (pickling the raw index object can fail on its underlying SWIG handles).
if os.path.exists(index_path) and os.path.exists(document_texts_path):
    try:
        index = faiss.read_index(index_path)
        with open(document_texts_path, "rb") as f:
            document_texts = pickle.load(f)
    except Exception as e:
        print(f"Error loading index: {e}")
        index = faiss.IndexFlatIP(embedding_dim)
        document_texts = []
else:
    index = faiss.IndexFlatIP(embedding_dim)
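# IndexFlatIP ranks by inner product; because get_embeddings() L2-normalizes
# its output, search scores are effectively cosine similarities in [-1, 1].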

# ===============================
# FILE EXTRACTORS
# ===============================
def extract_text_from_pdf(path):
    text = ""
    try:
        # Context manager ensures the PDF handle is closed after extraction.
        with fitz.open(path) as doc:
            for page in doc:
                text += page.get_text()
    except Exception as e:
        print(f"PDF error: {e}")
    return text

def extract_text_from_docx(path):
    text = ""
    try:
        doc = Document(path)
        text = "\n".join([para.text for para in doc.paragraphs])
    except Exception as e:
        print(f"DOCX error: {e}")
    return text

# ===============================
# UPLOAD HANDLER
# ===============================
def upload_document(file):
    # Newer Gradio versions pass uploads as a filepath string; older ones pass
    # a tempfile-like object exposing .name. Handle both.
    path = file if isinstance(file, str) else file.name
    ext = os.path.splitext(path)[-1].lower()
    if ext == ".pdf":
        text = extract_text_from_pdf(path)
    elif ext == ".docx":
        text = extract_text_from_docx(path)
    else:
        return "Unsupported file type."

    if not text.strip():
        return "No text could be extracted from the document."

    chunks = chunk_text(text)
    chunk_embeddings = get_embeddings(chunks)
    index.add(chunk_embeddings.astype("float32"))
    document_texts.extend(chunks)

    # Persist the index and the parallel chunk list together so they stay in sync.
    faiss.write_index(index, index_path)
    with open(document_texts_path, "wb") as f:
        pickle.dump(document_texts, f)

    return "Document uploaded and indexed successfully."

# ===============================
# GENERATION PIPELINE (FLAN-T5)
# ===============================
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
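# The pipeline defaults to CPU (device=-1); on a CUDA machine you could pass
# device=0 to run generation on the GPU instead.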

def generate_answer_from_file(query, top_k=5):
    if not document_texts or index.ntotal == 0:
        return "No documents indexed yet."

    query_vector = get_embeddings(query).astype("float32")
    # Never request more neighbors than the index holds, and skip the -1
    # placeholders FAISS returns when fewer matches exist.
    k = min(top_k, index.ntotal)
    scores, indices = index.search(query_vector, k=k)
    retrieved_chunks = [document_texts[i] for i in indices[0] if 0 <= i < len(document_texts)]
    context = " ".join(retrieved_chunks)

    prompt = (
        f"Use the following context from a textbook or academic document to answer the question accurately and in detail.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    # FLAN-T5's encoder accepts at most 512 tokens, so let the pipeline
    # truncate overly long prompts instead of raising an error.
    result = qa_pipeline(prompt, max_length=512, do_sample=False, truncation=True)[0]["generated_text"]
    return result.strip()

# ===============================
# GRADIO INTERFACES
# ===============================
upload_interface = gr.Interface(
    fn=upload_document,
    inputs=gr.File(file_types=[".pdf", ".docx"]),
    outputs="text",
    title="Upload Document",
    description="Upload your Word or PDF document for question answering."
)

search_interface = gr.Interface(
    fn=generate_answer_from_file,
    inputs=gr.Textbox(placeholder="Ask your question about the uploaded document..."),
    outputs="text",
    title="Ask the Document",
    description="Ask questions about the uploaded content. The chatbot will answer based on the document."
)

app = gr.TabbedInterface([upload_interface, search_interface], ["Upload", "Ask"])
app.launch()
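# launch() serves the app locally; passing share=True would additionally create
# a temporary public Gradio link, useful when running in a notebook.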