# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import os | |
import re | |
import faiss | |
import numpy as np | |
import requests | |
import pdfplumber | |
import spacy | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
from rank_bm25 import BM25Okapi | |
import gradio as gr | |
# Load Models
# Load the spaCy English pipeline. The model package is downloaded only when
# it is not already installed, so a warm start does not hit the network.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package cannot be found.
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Sentence-embedding model; used to embed chunks and queries for FAISS.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# NOTE(review): presumably intended for re-ranking; it is not referenced by
# any function in the visible part of this file.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
# Load API Key from Hugging Face Secrets
# Read the Gemini API key from the environment and fail fast at import time
# when it is missing or empty.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY is None or GEMINI_API_KEY == "":
    raise ValueError("π¨ Please set the Google API Key in Hugging Face Secrets!")

# REST endpoint used for text generation requests.
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# Financial Keywords for Filtering
# Lower-case finance terms, one per line. No function in the visible part of
# this file reads the list.
FINANCIAL_KEYWORDS = [
    "revenue",
    "profit",
    "loss",
    "balance sheet",
    "cash flow",
    "earnings",
    "expenses",
    "investment",
    "financial",
    "liability",
    "assets",
    "equity",
    "debt",
    "capital",
    "tax",
    "dividends",
    "reserves",
    "net income",
    "operating income",
]
# Global Variables for FAISS & BM25
# Module-level retrieval state; all three start out empty.
bm25 = None          # BM25 index over the chunk tokens
chunk_texts = []     # chunk strings
faiss_index = None   # FAISS index over the chunk embeddings
# 1. Extract and Clean Text from PDF
def extract_text_from_pdf(pdf_path):
    """Return the cleaned text of every page in the PDF at ``pdf_path``.

    Pages for which pdfplumber yields no text are skipped. The text of each
    remaining page is terminated with a newline, and the concatenation is
    passed through ``clean_text`` before being returned.
    """
    with pdfplumber.open(pdf_path) as document:
        page_texts = [page.extract_text() for page in document.pages]
    raw = "".join(f"{content}\n" for content in page_texts if content)
    return clean_text(raw)
# 2. Clean Extracted Text
def clean_text(text):
    """Strip boilerplate from text extracted from a PDF.

    The following are removed, in order:

    1. ``http://`` / ``https://`` URLs.
    2. Whole lines that start with a ``dd/dd/dddd`` date stamp.
    3. From the phrase "this data can be easily copy pasted" to the end of
       its line (case-insensitive).
    4. From "moneycontrol.com" to the end of its line (case-insensitive).
    5. Blank lines: every newline plus the whitespace that follows it is
       collapsed into a single newline.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text with leading and trailing whitespace stripped.
    """
    text = re.sub(r"https?://\S+", "", text)
    text = re.sub(r"^\d{2}/\d{2}/\d{4}.*$", "", text, flags=re.MULTILINE)
    # Use ".*$" rather than a trailing lazy ".*?": a lazy quantifier at the
    # end of a pattern always matches the empty string, which left the rest
    # of the boilerplate line in place.
    text = re.sub(
        r"(?i)this data can be easily copy pasted.*$", "", text, flags=re.MULTILINE
    )
    # The dot is escaped so only the literal domain name matches.
    text = re.sub(r"(?i)moneycontrol\.com.*$", "", text, flags=re.MULTILINE)
    text = re.sub(r"\n\s*", "\n", text)
    return text.strip()
# 3. Chunking Extracted Text
def chunk_text(text, max_tokens=64):
    """Group the sentences of ``text`` into chunks of bounded size.

    Sentences come from the spaCy pipeline ``nlp``. A sentence's size is its
    number of whitespace-separated tokens. Sentences are appended to the
    current chunk until adding the next one would exceed ``max_tokens``; the
    chunk is then closed and a new one is started.

    A single sentence longer than ``max_tokens`` is kept whole as its own
    chunk; sentences are never split.

    Args:
        text: Text to split.
        max_tokens: Maximum whitespace-token count per chunk (soft limit, see
            above).

    Returns:
        A list of chunk strings, each made of sentences joined by a space.
        The list is empty when ``text`` contains no sentences.
    """
    sentences = [sent.text for sent in nlp(text).sents]
    chunks, current_chunk, token_count = [], [], 0
    for sentence in sentences:
        sentence_tokens = len(sentence.split())
        # Only flush a non-empty buffer; otherwise an over-long sentence that
        # arrives while the buffer is empty would emit an empty chunk.
        if current_chunk and token_count + sentence_tokens > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk, token_count = [], 0
        current_chunk.append(sentence)
        token_count += sentence_tokens
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
# 4. Store Chunks in FAISS & BM25
def store_in_faiss(chunks):
    """Build the FAISS and BM25 indexes for ``chunks`` and publish them.

    The chunks are embedded with ``embed_model`` and added to a new
    ``faiss.IndexFlatL2``; a ``BM25Okapi`` index is built over the
    whitespace-split chunks. The module globals ``bm25``, ``chunk_texts`` and
    ``faiss_index`` are replaced together, and only after every step has
    succeeded, so a failure cannot leave them out of sync.

    Args:
        chunks: Non-empty sequence of chunk strings.

    Returns:
        The newly built FAISS index (also stored in ``faiss_index``).

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    global bm25, chunk_texts, faiss_index
    chunks = list(chunks)
    if not chunks:
        raise ValueError("Cannot build a search index from an empty list of chunks.")

    embeddings = embed_model.encode(chunks, convert_to_numpy=True)
    # FAISS works on C-contiguous float32 matrices.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)
    new_bm25 = BM25Okapi([chunk.split() for chunk in chunks])

    bm25, chunk_texts, faiss_index = new_bm25, chunks, new_index
    return faiss_index
# 5. Retrieve Chunks using BM25 with Scores
def retrieve_bm25(query, top_k=2):
    """Return the ``top_k`` best BM25 matches for ``query`` with scores.

    The query is split on whitespace and scored against the global ``bm25``
    index. Scores are min-max normalised over all chunks, so the best chunk
    gets 1.0 whenever the scores differ.

    When every chunk has the same score, the normalised score is 1.0 if that
    score is non-zero and 0.0 if it is zero (no query term scored against any
    chunk). This keeps a query that matches nothing from being reported with
    full confidence.

    Args:
        query: Search string.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of ``(chunk_text, normalised_score)`` tuples ordered from best
        to worst. Empty when ``top_k`` is not positive or nothing is indexed.
    """
    # A slice of [-0:] would select every chunk, so reject it up front.
    if top_k <= 0:
        return []
    scores = np.asarray(bm25.get_scores(query.split()), dtype=float)
    if scores.size == 0:
        return []

    top_indices = np.argsort(scores)[-top_k:][::-1]
    min_score, max_score = scores.min(), scores.max()
    score_range = max_score - min_score
    if score_range != 0:
        normalized = (scores[top_indices] - min_score) / score_range
    else:
        uniform = 1.0 if max_score != 0 else 0.0
        normalized = np.full(len(top_indices), uniform)

    return [
        (chunk_texts[index], float(confidence))
        for index, confidence in zip(top_indices, normalized)
    ]
# 6. Generate Response Using Google Gemini
def refine_with_gemini(query, retrieved_text, timeout=60):
    """Ask Gemini to extract the financial details relevant to ``query``.

    Args:
        query: The user's question; embedded verbatim in the prompt.
        retrieved_text: Retrieved context passed to the model as "Data".
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The generated text on success. Otherwise a human-readable message:
        a "no relevant data" notice when ``retrieved_text`` is blank, the API
        error message on a non-200 response, or a generic failure notice when
        the call raises. This function never raises.
    """
    if not retrieved_text.strip():
        return "β No relevant financial data found for your query."

    prompt = (
        "You are an expert financial analyst. Based on the provided data, "
        "extract only the relevant financial details related to the query: "
        f"'{query}' and present them in a clear format.\n\nData:\n{retrieved_text}"
    )
    payload = {"contents": [{"parts": [{"text": prompt}]}]}
    # The key travels in a header rather than the query string so that it
    # does not end up in the URL shown by logged exception messages.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": GEMINI_API_KEY,
    }

    # Broad catch is deliberate: this is the boundary with the UI, which
    # should get a message rather than a traceback.
    try:
        response = requests.post(
            GEMINI_API_URL, json=payload, headers=headers, timeout=timeout
        )
        response_json = response.json()
        if response.status_code != 200:
            print("π¨ Gemini API Error Response:", response_json)
            return f"β οΈ Gemini API Error: {response_json.get('error', {}).get('message', 'Unknown error')}"
        print("β Gemini API Response:", response_json)
        # "or [{}]" also covers keys that are present but hold an empty list.
        candidates = response_json.get("candidates") or [{}]
        parts = candidates[0].get("content", {}).get("parts") or [{}]
        return parts[0].get("text", "β οΈ Error generating response.")
    except Exception as e:
        print("π¨ Exception in Gemini API Call:", str(e))
        return "β οΈ Gemini API Exception: Unable to fetch response."
# 7. Final Retrieval Function with Confidence Score
def retrieve_and_generate_secure(query):
    """Answer ``query`` from the indexed PDF and attach a confidence score.

    BM25 selects the context chunks, which are sent to Gemini for the final
    answer. The confidence is a weighted mix: 60% the mean normalised BM25
    score of the retrieved chunks, 40% a similarity derived from the distance
    of the closest FAISS match to the query embedding.

    Args:
        query: The user's question.

    Returns:
        A formatted string holding the answer and the confidence percentage,
        or an explanatory message when no PDF is loaded, the query is blank,
        or nothing was retrieved.
    """
    print("π Query Received:", query)
    if bm25 is None or not chunk_texts:
        return "β No PDF data loaded. Please upload a PDF first."
    # A blank query cannot match anything; skip retrieval and the API call.
    if not query or not query.strip():
        return "Please enter a question."

    bm25_results = retrieve_bm25(query)
    if not bm25_results:
        return "β No relevant financial data found for your query."

    retrieved_texts, bm25_confidences = zip(*bm25_results)
    avg_bm25_confidence = sum(bm25_confidences) / len(bm25_confidences)

    query_embedding = embed_model.encode([query])
    distances, _ = faiss_index.search(query_embedding, 1)
    # Map the nearest-neighbour distance to (0, 1]: distance 0 gives 1.0.
    # Clamp at zero to guard against tiny negative rounding artefacts, and
    # convert to a plain float so no numpy scalar leaks into the output.
    nearest_distance = max(0.0, float(distances[0][0]))
    faiss_confidence = 1.0 / (1.0 + nearest_distance)

    final_confidence = (0.6 * avg_bm25_confidence) + (0.4 * faiss_confidence)

    final_answer = refine_with_gemini(query, "\n".join(retrieved_texts))
    return f"π¬ Answer: {final_answer}\n\nπΉ Confidence Score: {round(final_confidence * 100, 2)}%"
# 8. Load PDF and Process Data
def process_uploaded_pdf(pdf_file):
    """Extract, chunk and index the uploaded PDF.

    Args:
        pdf_file: The value delivered by the upload component: either a path
            string or an object whose ``name`` attribute is the path. ``None``
            when nothing has been uploaded.

    Returns:
        A status message for the UI: success, or an explanation when no file
        was provided or the PDF contained no extractable text.
    """
    global faiss_index
    if pdf_file is None:
        return "Please upload a PDF file before processing."

    # Accept both a plain path string and a file-like wrapper with ``.name``.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text)
    if not chunks:
        # Nothing to index (e.g. a scanned PDF without a text layer).
        return "No extractable text was found in the PDF."

    faiss_index = store_in_faiss(chunks)
    return "β PDF Processed Successfully! Now you can ask financial questions."
# 9. Build Gradio UI
# NOTE(review): the captured source lost its indentation. The grouping of
# components inside each gr.Row() below is reconstructed; confirm it against
# the deployed layout.
with gr.Blocks() as app:
    gr.Markdown("# π Financial RAG Model")
    gr.Markdown("Upload a company financial report PDF and ask relevant financial questions.")

    # Upload row: file picker, trigger button and status read-out.
    with gr.Row():
        pdf_input = gr.File(label="π Upload Financial PDF", type="filepath")
        process_button = gr.Button("π Process PDF")
        status_output = gr.Textbox(label="Processing Status", interactive=False)

    # Question row: query box, answer read-out and trigger button.
    with gr.Row():
        query_input = gr.Textbox(label="β Ask a financial question")
        answer_output = gr.Textbox(label="π¬ Answer", interactive=False)
        query_button = gr.Button("π Get Answer")

    # Wire each button to its handler.
    process_button.click(
        fn=process_uploaded_pdf,
        inputs=pdf_input,
        outputs=status_output,
    )
    query_button.click(
        fn=retrieve_and_generate_secure,
        inputs=query_input,
        outputs=answer_output,
    )

# 10. Launch UI
app.launch()