from io import BytesIO
from typing import List, Tuple
import os

import faiss  # FAISS for vector similarity search
import fitz  # PyMuPDF, for PDF text extraction
import numpy as np
from flask import Flask, request, jsonify, render_template
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer

app = Flask(__name__, template_folder=os.getcwd())

# Default settings
class ChatConfig:
    MODEL = "google/gemma-3-27b-it"
    DEFAULT_SYSTEM_MSG = "You are an AI assistant answering only based on the uploaded PDF."
    DEFAULT_MAX_TOKENS = 512
    DEFAULT_TEMP = 0.3
    DEFAULT_TOP_P = 0.95

client = InferenceClient(ChatConfig.MODEL)  # Uses the locally configured HF token, if any (Gemma is a gated model)
embed_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp")

vector_dim = 384  # Embedding size of all-MiniLM-L6-v2
index = faiss.IndexFlatL2(vector_dim)  # FAISS index (exact L2 search)
documents = []  # Extracted text chunks, one per page

def extract_text_from_pdf(pdf_stream):
    """Extract text from a PDF stream, one chunk per page."""
    doc = fitz.open(stream=pdf_stream, filetype="pdf")
    text_chunks = [page.get_text("text") for page in doc]
    doc.close()
    return text_chunks

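# Note: chunks are whole pages here. Finer-grained chunking (e.g. by paragraph,
# possibly with overlap) would usually improve retrieval precision, but
# page-level chunks keep this sketch simple.
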
def create_vector_db(text_chunks):
    """Embed text chunks and add them to a fresh FAISS index."""
    global documents, index
    # Reinitialize the index so each upload replaces the previous document
    index = faiss.IndexFlatL2(vector_dim)
    documents = text_chunks
    # FAISS expects a 2-D float32 array of shape (n_vectors, vector_dim)
    embeddings = np.array(embed_model.encode(text_chunks), dtype=np.float32)
    if embeddings.ndim == 1:  # A single chunk yields a 1-D vector; reshape it
        embeddings = embeddings.reshape(1, -1)
    index.add(embeddings)
    if index.ntotal == 0:
        print("Error: FAISS index is empty after adding embeddings.")

def search_relevant_text(query):
    """Return the text of the chunks most similar to the query."""
    query_embedding = np.array(embed_model.encode([query]), dtype=np.float32)
    k = min(3, len(documents))  # Don't ask FAISS for more neighbors than exist
    _, closest_idx = index.search(query_embedding, k=k)
    return "\n".join(documents[i] for i in closest_idx[0])

def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str = ChatConfig.DEFAULT_SYSTEM_MSG,
    max_tokens: int = ChatConfig.DEFAULT_MAX_TOKENS,
    temperature: float = ChatConfig.DEFAULT_TEMP,
    top_p: float = ChatConfig.DEFAULT_TOP_P,
) -> str:
    if not documents:
        return "Please upload a PDF first."
    context = search_relevant_text(message)  # Retrieve relevant chunks from the PDF
    # Rebuild the chat history in the format expected by the chat-completion API
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": f"Context: {context}\nQuestion: {message}"})
    # Stream the completion and accumulate the tokens into a single string
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        response += chunk.choices[0].delta.content or ""
    return response

@app.route('/')
def home():  # Renamed from `index` so it does not shadow the FAISS index above
    """Serve the HTML page for the user interface."""
    return render_template('index.html')

@app.route('/upload_pdf', methods=['POST'])  # Route path assumed; adjust to match the front end
def upload_pdf():
    """Handle a PDF upload and (re)build the vector index."""
    if 'pdf' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['pdf']
    if file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    try:
        # Read the file into memory instead of saving it to disk
        pdf_stream = BytesIO(file.read())
        text_chunks = extract_text_from_pdf(pdf_stream)
        create_vector_db(text_chunks)
        return jsonify({"message": "PDF uploaded and indexed successfully!"}), 200
    except Exception as e:
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500

@app.route('/ask_question', methods=['POST'])  # Route path assumed; adjust to match the front end
def ask_question():
    """Answer a user question against the indexed PDF."""
    message = request.json.get('message')
    history = request.json.get('history', [])
    response = generate_response(message, history)
    return jsonify({"response": response})

if __name__ == '__main__':
    app.run(debug=True)
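
# --- Usage sketch (illustrative, not part of the app) ---
# A minimal smoke test for the two endpoints, assuming the route paths above
# ('/upload_pdf' and '/ask_question') and the Flask dev server running on
# 127.0.0.1:5000. Run it from a separate process, e.g. with `requests`:
#
#   import requests
#
#   with open("sample.pdf", "rb") as f:
#       r = requests.post("http://127.0.0.1:5000/upload_pdf", files={"pdf": f})
#   print(r.json())  # {"message": "PDF uploaded and indexed successfully!"}
#
#   r = requests.post(
#       "http://127.0.0.1:5000/ask_question",
#       json={"message": "What is this document about?", "history": []},
#   )
#   print(r.json()["response"])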