# Page header captured from the Hugging Face file viewer (not part of the program):
# Pamudu13's picture
# Update app.py
# 47994b5 verified
# raw
# history blame
# 4.94 kB
from flask import Flask, request, jsonify, render_template
import fitz # PyMuPDF for PDF text extraction
import faiss # FAISS for vector search
import numpy as np
import os
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
from typing import List, Tuple
# Templates (index.html) are resolved relative to the current working
# directory, i.e. wherever the server process is launched from.
_TEMPLATE_DIR = os.getcwd()
app = Flask(__name__, template_folder=_TEMPLATE_DIR)
# Default settings
class ChatConfig:
    """Default LLM settings shared by the inference client and generate_response."""

    # Hugging Face model id used for chat completion.
    MODEL: str = "mistralai/Mistral-7B-Instruct-v0.2"
    # System prompt placed first in every conversation.
    DEFAULT_SYSTEM_MSG: str = "You are an AI assistant answering only based on the uploaded PDF."
    # Generation defaults: token budget, sampling temperature, nucleus top-p.
    DEFAULT_MAX_TOKENS: int = 512
    DEFAULT_TEMP: float = 0.3
    DEFAULT_TOP_P: float = 0.95
# Hugging Face API token, read from the HUGGINGFACE_TOKEN environment variable.
# os.getenv returns None when it is unset; that value is passed to the client as-is.
HF_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
client = InferenceClient(
    ChatConfig.MODEL,
    token=HF_TOKEN,
)

# Sentence-embedding model used for both PDF chunks and user queries.
embed_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp")
# Embedding size as reported by the model (384 for all-MiniLM-L6-v2). Deriving it
# keeps the FAISS index dimension in sync if the model name above is changed;
# fall back to 384 if the model does not report a dimension.
vector_dim = embed_model.get_sentence_embedding_dimension() or 384
index = faiss.IndexFlatL2(vector_dim)  # FAISS index; rebuilt by create_vector_db
documents = []  # Extracted text chunks; documents[i] matches row i of `index`
def extract_text_from_pdf(pdf_stream):
    """Extract the plain text of every page of a PDF.

    Args:
        pdf_stream: PDF contents, passed straight to ``fitz.open(stream=...)``
            (raw bytes or an in-memory buffer such as ``io.BytesIO``).

    Returns:
        A list with one string per page, in page order.
    """
    # The context manager closes the document even if text extraction raises,
    # which the previous open/close pair did not guarantee.
    with fitz.open(stream=pdf_stream, filetype="pdf") as doc:
        return [page.get_text("text") for page in doc]
def create_vector_db(text_chunks):
    """Embed text chunks and replace the global FAISS index with them.

    Builds a fresh ``IndexFlatL2`` and then rebinds the module globals
    ``index`` and ``documents`` so that row ``i`` of the index corresponds to
    ``documents[i]``. An empty ``text_chunks`` resets both globals to an
    empty state instead of failing inside FAISS.

    Args:
        text_chunks: Iterable of strings (one per PDF page when called from
            ``upload_pdf``).
    """
    global documents, index

    chunks = list(text_chunks)
    new_index = faiss.IndexFlatL2(vector_dim)

    if chunks:
        # FAISS requires contiguous float32 input.
        embeddings = np.asarray(embed_model.encode(chunks), dtype=np.float32)
        # Defensive: FAISS expects a 2-D (n_vectors, dim) array.
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        new_index.add(embeddings)

    # Publish the globals only after the index is fully built. If encode/add
    # raises, the previous documents/index pair stays intact and consistent.
    index = new_index
    documents = chunks
def search_relevant_text(query, k=3):
    """Return the text of the indexed chunks most similar to ``query``.

    Args:
        query: Question text to embed and search with.
        k: Maximum number of chunks to retrieve (default 3). It is clamped to
            the number of indexed documents.

    Returns:
        The retrieved chunks joined with newlines, in the order FAISS returns
        them. Returns an empty string when no documents are indexed.
    """
    k = min(k, len(documents))
    if k <= 0:
        return ""

    query_embedding = np.asarray(embed_model.encode([query]), dtype=np.float32)
    _, closest_idx = index.search(query_embedding, k=k)
    # FAISS pads missing results with -1. Without this filter, documents[-1]
    # would silently return the last chunk again.
    return "\n".join(
        documents[i] for i in closest_idx[0] if 0 <= i < len(documents)
    )
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str = ChatConfig.DEFAULT_SYSTEM_MSG,
    max_tokens: int = ChatConfig.DEFAULT_MAX_TOKENS,
    temperature: float = ChatConfig.DEFAULT_TEMP,
    top_p: float = ChatConfig.DEFAULT_TOP_P
) -> str:
    """Answer ``message`` using retrieved PDF context and the chat model.

    Args:
        message: The user's current question.
        history: Prior turns as ``(user_msg, bot_msg)`` pairs. Falsy entries
            on either side are skipped.
        system_message: System prompt placed first in the conversation.
        max_tokens: Generation token budget passed to ``chat_completion``.
        temperature: Sampling temperature passed to ``chat_completion``.
        top_p: Nucleus-sampling value passed to ``chat_completion``.

    Returns:
        The full streamed completion text, or "Please upload a PDF first."
        when no document has been indexed.
    """
    if not documents:
        return "Please upload a PDF first."

    context = search_relevant_text(message)  # Relevant content from the PDF

    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    # The retrieved context is attached only to the current question.
    messages.append({"role": "user", "content": f"Context: {context}\nQuestion: {message}"})

    # Collect streamed pieces and join once, instead of quadratic `+=`.
    parts = []
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Guard against stream events that carry no choices; indexing
        # choices[0] on those would raise IndexError.
        if not chunk.choices:
            continue
        parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)
@app.route('/', endpoint='index')
def serve_index():
    """Serve the HTML page for the user interface (``index.html``).

    This view is deliberately not named ``index``. A module-level
    ``def index`` would rebind the global FAISS ``index`` to this function,
    leaving a view function where ``search_relevant_text`` expects an index.
    ``endpoint='index'`` keeps the original endpoint name, so
    ``url_for('index')`` still resolves.
    """
    return render_template('index.html')
@app.route('/upload_pdf', methods=['POST'])
def upload_pdf():
    """Handle a PDF upload: extract per-page text and rebuild the vector index.

    Expects a multipart form with the file under the ``pdf`` field. The file
    is processed in memory; nothing is written to disk.

    Returns:
        JSON ``{"message": ...}`` with 200 on success. Otherwise JSON
        ``{"error": ...}``: 400 for a missing, unnamed or empty file, and 500
        if reading, parsing or indexing fails.
    """
    if 'pdf' not in request.files:
        return jsonify({"error": "No file part"}), 400

    file = request.files['pdf']
    if file.filename == "":
        return jsonify({"error": "No selected file"}), 400

    try:
        pdf_bytes = file.read()
        if not pdf_bytes:
            return jsonify({"error": "Uploaded file is empty"}), 400

        # Reuse the shared extractor (fitz accepts raw bytes as a stream)
        # rather than duplicating the open/extract/close logic here.
        text_chunks = extract_text_from_pdf(pdf_bytes)
        create_vector_db(text_chunks)
    except Exception as e:
        app.logger.exception("Failed to process uploaded PDF")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500

    return jsonify({"message": "PDF uploaded and indexed successfully!"}), 200
@app.route('/ask_question', methods=['POST'])
def ask_question():
    """Answer a user question about the uploaded PDF.

    Expects a JSON object body ``{"message": str, "history": [[user, bot], ...]}``.
    ``history`` is optional, and null is treated as empty.

    Returns:
        JSON ``{"response": str}`` on success. Malformed input gets a 400 with
        an ``error`` key instead of an unhandled 500. A failure while
        generating the answer gets a JSON 500.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    message = payload.get('message')
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "'message' must be a non-empty string"}), 400

    history = payload.get('history') or []
    # generate_response unpacks every entry as a (user, bot) pair.
    if not isinstance(history, list) or not all(
        isinstance(turn, (list, tuple)) and len(turn) == 2 for turn in history
    ):
        return jsonify({"error": "'history' must be a list of [user, bot] pairs"}), 400

    try:
        response = generate_response(message, history)
    except Exception:
        # Log details server-side; do not echo upstream error text to clients.
        app.logger.exception("Failed to generate response")
        return jsonify({"error": "Failed to generate a response"}), 500
    return jsonify({"response": response})
if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so enable it only when explicitly requested via FLASK_DEBUG.
    _debug = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=_debug)