from flask import Flask, request, jsonify
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
import os
from dotenv import load_dotenv
from flask_cors import CORS
import base64
import tempfile
import io
from pathlib import Path
# Load environment variables
load_dotenv()

app = Flask(__name__)
CORS(app)

# Increase maximum content length to 32 MB
app.config['MAX_CONTENT_LENGTH'] = 32 * 1024 * 1024
# Global variables
qa_chain = None
vector_db = None
api_token = os.getenv("HF_TOKEN")
pdf_chunks = {}

app.config['UPLOAD_FOLDER'] = 'temp_uploads'
# Create the upload folder if it doesn't exist
Path(app.config['UPLOAD_FOLDER']).mkdir(parents=True, exist_ok=True)
# Available LLM models
LLM_MODELS = {
    "llama": "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistral": "mistralai/Mistral-7B-Instruct-v0.2"
}

# In-progress chunked upload state (used by the start/chunk/finish endpoints below)
current_upload = {
    'filename': None,
    'chunks': [],
    'filesize': 0
}
def load_doc(file_paths):
    """Load and split multiple PDF documents"""
    loaders = [PyPDFLoader(path) for path in file_paths]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
def create_db(splits):
    """Create vector database from document splits"""
    embeddings = HuggingFaceEmbeddings()
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Initialize the LLM chain"""
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
def format_chat_history(message, chat_history):
    """Format chat history for the LLM"""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
# Route paths below are assumed from the function names; adjust them to match your client.
@app.route('/upload_pdf', methods=['POST'])
def upload_pdf():
    """Handle PDF upload (base64-encoded) and vector database initialization"""
    global vector_db
    if 'pdf_base64' not in request.json:
        return jsonify({'error': 'No PDF data provided'}), 400
    try:
        # Get base64 PDF and filename
        pdf_base64 = request.json['pdf_base64']
        filename = request.json.get('filename', 'uploaded.pdf')

        # Create temp directory if it doesn't exist
        os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
        temp_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        try:
            # Decode and save the PDF
            pdf_data = base64.b64decode(pdf_base64)
            with open(temp_path, 'wb') as f:
                f.write(pdf_data)

            # Process the document and build the vector store
            doc_splits = load_doc([temp_path])
            vector_db = create_db(doc_splits)
            return jsonify({'message': 'PDF processed successfully'}), 200
        finally:
            # Clean up the temporary file
            if os.path.exists(temp_path):
                os.remove(temp_path)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
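# Example request body for the endpoint above (a sketch; the field names come from the
# handler, the values are illustrative):
#   {"pdf_base64": "<base64-encoded PDF bytes>", "filename": "report.pdf"}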
@app.route('/init_llm', methods=['POST'])
def init_llm():
    """Initialize the LLM with parameters"""
    global qa_chain, vector_db
    if vector_db is None:
        return jsonify({'error': 'Please upload PDFs first'}), 400
    data = request.json
    model_name = data.get('model', 'llama')  # default to llama
    temperature = data.get('temperature', 0.5)
    max_tokens = data.get('max_tokens', 4096)
    top_k = data.get('top_k', 3)
    if model_name not in LLM_MODELS:
        return jsonify({'error': 'Invalid model name'}), 400
    try:
        qa_chain = initialize_llmchain(
            LLM_MODELS[model_name],
            temperature,
            max_tokens,
            top_k,
            vector_db
        )
        return jsonify({'message': 'LLM initialized successfully'}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500
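# Example request body for the endpoint above (a sketch; keys match the data.get() calls
# in the handler, values are illustrative):
#   {"model": "mistral", "temperature": 0.5, "max_tokens": 4096, "top_k": 3}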
@app.route('/chat', methods=['POST'])
def chat():
    """Handle chat interactions"""
    global qa_chain
    if qa_chain is None:
        return jsonify({'error': 'LLM not initialized'}), 400
    data = request.json
    question = data.get('question')
    chat_history = data.get('chat_history', [])
    if not question:
        return jsonify({'error': 'No question provided'}), 400
    try:
        formatted_history = format_chat_history(question, chat_history)
        result = qa_chain({"question": question, "chat_history": formatted_history})

        # Strip prompt scaffolding that some models echo back
        answer = result['answer']
        if "Helpful Answer:" in answer:
            answer = answer.split("Helpful Answer:")[-1]

        # Extract up to three source passages
        sources = []
        for doc in result['source_documents'][:3]:
            sources.append({
                'content': doc.page_content.strip(),
                'page': doc.metadata.get('page', 0) + 1  # convert to 1-based page numbers
            })

        response = {
            'answer': answer,
            'sources': sources
        }
        return jsonify(response), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500
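# Example request/response for the endpoint above (a sketch; the shapes follow the handler,
# the values are made up):
#   request:  {"question": "What is the report about?",
#              "chat_history": [["earlier user turn", "earlier assistant turn"]]}
#   response: {"answer": "...", "sources": [{"content": "...", "page": 1}]}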
@app.route('/upload_local', methods=['POST'])
def upload_local():
    """Handle PDF upload from the local file system"""
    global vector_db
    data = request.json
    file_path = data.get('file_path')
    if not file_path or not os.path.exists(file_path):
        return jsonify({'error': 'File not found'}), 400
    try:
        # Process the document and build the vector store
        doc_splits = load_doc([file_path])
        vector_db = create_db(doc_splits)
        return jsonify({'message': 'PDF processed successfully'}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/start_upload', methods=['POST'])
def start_upload():
    """Initialize a new chunked file upload"""
    global current_upload
    data = request.json
    current_upload = {
        'filename': data['filename'],
        'chunks': [],
        'filesize': data['filesize']
    }
    return jsonify({'message': 'Upload started'}), 200
@app.route('/upload_chunk', methods=['POST'])
def upload_chunk():
    """Receive one base64-encoded chunk of the file"""
    global current_upload
    if not current_upload['filename']:
        return jsonify({'error': 'No upload in progress'}), 400
    try:
        chunk = base64.b64decode(request.json['chunk'])
        current_upload['chunks'].append(chunk)
        return jsonify({'message': 'Chunk received'}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/finish_upload', methods=['POST'])
def finish_upload():
    """Reassemble the chunks and process the complete file"""
    global current_upload, vector_db
    if not current_upload['filename']:
        return jsonify({'error': 'No upload in progress'}), 400
    # Build temp_path before the try block so the error handler can reference it safely
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    temp_path = os.path.join(app.config['UPLOAD_FOLDER'], current_upload['filename'])
    try:
        # Combine chunks and save the file
        with open(temp_path, 'wb') as f:
            for chunk in current_upload['chunks']:
                f.write(chunk)

        # Process the PDF and build the vector store
        doc_splits = load_doc([temp_path])
        vector_db = create_db(doc_splits)

        # Cleanup
        os.remove(temp_path)
        current_upload['chunks'] = []
        current_upload['filename'] = None
        return jsonify({'message': 'PDF processed successfully'}), 200
    except Exception as e:
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return jsonify({'error': str(e)}), 500
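# The three endpoints above implement a simple chunked upload: the client calls
# start_upload with {"filename": ..., "filesize": ...}, posts each base64-encoded
# piece to upload_chunk as {"chunk": ...}, then calls finish_upload to reassemble
# and index the PDF. The state lives in a single module-level dict, so only one
# upload can be in progress at a time.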
if __name__ == '__main__':
    app.run(debug=True)
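# Example end-to-end client flow (a sketch; assumes the server runs on
# http://localhost:5000, that the assumed route paths above are used unchanged,
# and that "report.pdf" is a hypothetical local file):
#
#   import base64, requests
#
#   base = "http://localhost:5000"
#   pdf_b64 = base64.b64encode(open("report.pdf", "rb").read()).decode()
#   requests.post(f"{base}/upload_pdf", json={"pdf_base64": pdf_b64, "filename": "report.pdf"})
#   requests.post(f"{base}/init_llm", json={"model": "llama"})
#   reply = requests.post(f"{base}/chat", json={"question": "Summarize the document", "chat_history": []})
#   print(reply.json()["answer"])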