# Financial_RAG / app.py
# Last commit b907e11 by alfa95
import os
import re
import faiss
import numpy as np
import requests
import pdfplumber
import spacy
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
import gradio as gr
# Load models.
# spaCy sentence segmenter: load the installed package, and only download it
# when it is missing instead of re-running a pip download on every start-up.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spaCy raises OSError (E050) when the model package cannot be found.
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Bi-encoder producing the dense embeddings stored in the FAISS index.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Cross-encoder re-ranker.
# NOTE(review): not referenced by any function visible in this file.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
# Gemini credentials are read from the environment (Hugging Face Secrets).
# Fail fast at import time: nothing below can answer questions without them.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY is None or GEMINI_API_KEY == "":
    raise ValueError("🚨 Please set the Google API Key in Hugging Face Secrets!")

# REST endpoint of the gemini-2.0-flash generateContent method.
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# Vocabulary of financial topics.
# NOTE(review): not referenced by any function visible in this file.
FINANCIAL_KEYWORDS = [
    "revenue",
    "profit",
    "loss",
    "balance sheet",
    "cash flow",
    "earnings",
    "expenses",
    "investment",
    "financial",
    "liability",
    "assets",
    "equity",
    "debt",
    "capital",
    "tax",
    "dividends",
    "reserves",
    "net income",
    "operating income",
]

# Module-level retrieval state; store_in_faiss() rebinds all three when a PDF
# is processed, and the retrieval functions read them.
bm25 = None         # BM25Okapi index over the chunk tokens
chunk_texts = []    # chunk strings, positionally aligned with both indexes
faiss_index = None  # faiss.IndexFlatL2 over the chunk embeddings
# 1. Extract and clean the text of a PDF.
def extract_text_from_pdf(pdf_path):
    """Return the cleaned text of every page of the PDF at ``pdf_path``.

    Pages for which pdfplumber extracts no text are skipped. Every other page
    contributes its text followed by a newline, and the concatenation is
    passed through clean_text().
    """
    with pdfplumber.open(pdf_path) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    raw = "".join(f"{page_text}\n" for page_text in page_texts if page_text)
    return clean_text(raw)
# 2. Clean extracted text.
def clean_text(text):
    """Strip boilerplate from text extracted out of a financial-report PDF.

    Removes, in order:
      * http/https URLs;
      * whole lines that start with a ``dd/dd/dddd`` date stamp;
      * the "this data can be easily copy pasted" notice, through end of line;
      * "moneycontrol.com" source attributions, through end of line;
    then collapses each run of newlines (together with any whitespace that
    follows them) into a single newline and trims both ends.

    Args:
        text: Raw extracted text.

    Returns:
        The cleaned text.
    """
    text = re.sub(r"https?://\S+", "", text)  # URLs
    text = re.sub(r"^\d{2}/\d{2}/\d{4}.*$", "", text, flags=re.MULTILINE)  # timestamp lines
    # A trailing lazy ".*?" matches the empty string, so only the phrase itself
    # used to be removed; greedy ".*" (which stops at "\n") drops the rest of
    # the line as the cleanup intends.
    text = re.sub(r"(?i)this data can be easily copy pasted.*", "", text)  # metadata
    # Escaped dot: match the literal domain, not "moneycontrol" + any char + "com".
    text = re.sub(r"(?i)moneycontrol\.com.*", "", text)  # source attribution
    text = re.sub(r"(\n\s*)+", "\n", text)  # blank lines
    return text.strip()
# 3. Chunk extracted text.
def chunk_text(text, max_tokens=64):
    """Group the spaCy sentences of ``text`` into chunks of about ``max_tokens`` words.

    "Tokens" are whitespace-separated words (``str.split``), not model tokens.
    Sentences are never split, so a single sentence longer than ``max_tokens``
    becomes a chunk on its own.

    Args:
        text: Text to chunk.
        max_tokens: Word budget per chunk.

    Returns:
        List of non-empty chunk strings (sentences joined by single spaces);
        an empty list when spaCy finds no sentences.
    """
    doc = nlp(text)
    chunks, current_chunk = [], []
    token_count = 0
    for sentence in (sent.text for sent in doc.sents):
        n_tokens = len(sentence.split())
        # Only flush a non-empty chunk: an over-long first sentence used to
        # emit an empty "" chunk into the indexes.
        if current_chunk and token_count + n_tokens > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            token_count = 0
        current_chunk.append(sentence)
        token_count += n_tokens
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
# 4. Store chunks in FAISS & BM25.
def _tokenize(text):
    """Lower-case ``text`` and split it into word tokens for BM25.

    Used for both the indexed chunks and the queries, so "Revenue",
    "revenue," and "revenue?" all map to the same term.
    """
    return re.findall(r"\w+", text.lower())


def store_in_faiss(chunks):
    """Build the dense (FAISS) and sparse (BM25) indexes over ``chunks``.

    Side effects: rebinds the module globals ``faiss_index``, ``chunk_texts``
    and ``bm25`` so the retrieval functions see the new document.

    Args:
        chunks: List of chunk strings. NOTE(review): an empty list is not
            handled here; callers should pass at least one chunk.

    Returns:
        The new ``faiss.IndexFlatL2`` (the same object stored in ``faiss_index``).
    """
    global bm25, chunk_texts, faiss_index
    embeddings = embed_model.encode(chunks, convert_to_numpy=True)
    # Exact (brute-force) L2 index sized to the embedding dimension.
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    chunk_texts = chunks
    bm25 = BM25Okapi([_tokenize(chunk) for chunk in chunks])
    return faiss_index


# 5. Retrieve chunks using BM25 with scores.
def retrieve_bm25(query, top_k=2):
    """Return the ``top_k`` best BM25 matches for ``query`` with normalized scores.

    Scores are min-max normalized over *all* chunks, so the best chunk gets
    1.0 whenever the scores differ. When every chunk scores the same, the
    normalized value is 1.0 for a non-zero common score and 0.0 when it is
    zero (no query term matched), rather than claiming full confidence.

    Args:
        query: User question.
        top_k: Number of chunks to return; ``<= 0`` returns an empty list.

    Returns:
        List of ``(chunk_text, normalized_score)`` tuples, best first.
    """
    if top_k <= 0:
        # np.argsort(...)[-0:] would have selected every chunk.
        return []
    scores = bm25.get_scores(_tokenize(query))
    top_indices = np.argsort(scores)[::-1][:top_k]  # best first
    min_score, max_score = np.min(scores), np.max(scores)
    span = max_score - min_score
    if span == 0:
        flat = 1.0 if max_score != 0 else 0.0
        return [(chunk_texts[i], flat) for i in top_indices]
    return [(chunk_texts[i], (scores[i] - min_score) / span) for i in top_indices]
# 6. Generate a response using Google Gemini.
def refine_with_gemini(query, retrieved_text, timeout=60):
    """Ask Gemini to extract the query-relevant financial details from ``retrieved_text``.

    Args:
        query: User question, embedded verbatim in the prompt.
        retrieved_text: Retrieved chunk text used as the prompt's data section.
        timeout: Seconds to wait on the HTTP call before giving up.

    Returns:
        The model's text, or a user-facing notice/error string. Network,
        HTTP and malformed-response failures are reported as strings, not raised.
    """
    if not retrieved_text.strip():
        return "❌ No relevant financial data found for your query."
    payload = {
        "contents": [{
            "parts": [{
                "text": f"You are an expert financial analyst. Based on the provided data, extract only the relevant financial details related to the query: '{query}' and present them in a clear format.\n\nData:\n{retrieved_text}"
            }]
        }]
    }
    # Send the key in a header rather than a "?key=" query string: requests
    # includes the request URL in its exception messages, which are printed
    # below and would otherwise leak the key into the logs.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": GEMINI_API_KEY,
    }
    try:
        # Without a timeout a stalled connection would block this handler forever.
        response = requests.post(GEMINI_API_URL, json=payload, headers=headers, timeout=timeout)
        response_json = response.json()
        if response.status_code != 200:
            print("🚨 Gemini API Error Response:", response_json)
            return f"⚠️ Gemini API Error: {response_json.get('error', {}).get('message', 'Unknown error')}"
        print("βœ… Gemini API Response:", response_json)
        # Walk candidates[0].content.parts[0].text, tolerating missing or empty
        # levels (e.g. a response that carries no candidates at all).
        candidates = response_json.get("candidates") or [{}]
        parts = (candidates[0].get("content") or {}).get("parts") or [{}]
        return parts[0].get("text", "⚠️ Error generating response.")
    except Exception as e:
        # UI boundary: report any network/parsing failure as a message.
        print("🚨 Exception in Gemini API Call:", str(e))
        return "⚠️ Gemini API Exception: Unable to fetch response."
# 7. Final retrieval function with a confidence score.
def retrieve_and_generate_secure(query):
    """Answer ``query`` from the loaded PDF and append a heuristic confidence score.

    Steps:
      1. Take the BM25 top chunks and average their normalized BM25 scores.
      2. Turn the distance ``d`` of the single nearest FAISS chunk into a
         similarity ``1 / (1 + d)``.
      3. Blend the two scores 60/40.
      4. Send the BM25 chunks to Gemini for the answer text.

    NOTE(review): the FAISS neighbour is looked up independently of the BM25
    chunks, so it need not be one of the chunks actually sent to Gemini.

    Returns:
        A formatted answer string, or a notice when no PDF is loaded or
        nothing was retrieved.
    """
    print("πŸ” Query Received:", query)
    if bm25 is None or not chunk_texts:
        return "❌ No PDF data loaded. Please upload a PDF first."

    hits = retrieve_bm25(query)
    if not hits:
        return "❌ No relevant financial data found for your query."

    texts = [text for text, _ in hits]
    bm25_scores = [score for _, score in hits]
    bm25_part = sum(bm25_scores) / len(bm25_scores)

    distances, _ = faiss_index.search(embed_model.encode([query]), 1)
    nearest = distances[0][0]
    if nearest != 0:
        faiss_part = 1 / (1 + nearest)
    else:
        faiss_part = 1

    confidence = 0.6 * bm25_part + 0.4 * faiss_part
    answer = refine_with_gemini(query, "\n".join(texts))
    return f"πŸ’¬ Answer: {answer}\n\nπŸ”Ή Confidence Score: {round(confidence * 100, 2)}%"
# 8. Load a PDF and build the retrieval indexes.
def process_uploaded_pdf(pdf_file):
    """Extract, chunk and index an uploaded PDF so questions can be asked about it.

    Args:
        pdf_file: The Gradio upload. This is either a path string or an
            object exposing the path as ``.name``, or ``None`` when nothing
            was uploaded.

    Returns:
        A status message for the UI.
    """
    global faiss_index
    if pdf_file is None:
        return "❌ Please upload a PDF file first."
    # Accept both a plain path string and wrappers that carry it in ``.name``.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text)
    if not chunks:
        # Nothing to index (e.g. an image-only PDF); store_in_faiss does not
        # handle an empty chunk list.
        return "❌ No extractable text found in the PDF."
    faiss_index = store_in_faiss(chunks)
    return "βœ… PDF Processed Successfully! Now you can ask financial questions."
# 9. Build the Gradio UI.
# NOTE(review): indentation was lost in this copy of the file; which components
# sit inside each gr.Row() below is reconstructed — confirm the intended layout.
with gr.Blocks() as app:
    gr.Markdown("# πŸ“Š Financial RAG Model")
    gr.Markdown("Upload a company financial report PDF and ask relevant financial questions.")
    # Upload controls; type="filepath" makes the handler receive the file's path.
    with gr.Row():
        pdf_input = gr.File(label="πŸ“‚ Upload Financial PDF", type="filepath")
        process_button = gr.Button("πŸ“œ Process PDF")
    status_output = gr.Textbox(label="Processing Status", interactive=False)
    # Question / answer controls.
    with gr.Row():
        query_input = gr.Textbox(label="❓ Ask a financial question")
        answer_output = gr.Textbox(label="πŸ’¬ Answer", interactive=False)
    query_button = gr.Button("πŸ” Get Answer")
    # Events: "Process PDF" builds the indexes; "Get Answer" runs retrieval + Gemini.
    process_button.click(process_uploaded_pdf, inputs=pdf_input, outputs=status_output)
    query_button.click(retrieve_and_generate_secure, inputs=query_input, outputs=answer_output)
# 10. Launch the UI only when this file is run as a script (`python app.py`),
# so importing the module does not start a blocking web server.
if __name__ == "__main__":
    app.launch()