# Abhinav Gavireddi
# fix: fixed bugs in UI
# 04db7e0
"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.

This module defines:
- RerankerConfig: model name and device settings for the re-ranker
- Reranker: Cross-encoder based re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation

It relies on:
- Retriever (imported from src.retriever): Hybrid BM25 + dense retrieval over parsed chunks

Each component is modular and can be swapped or extended (e.g., add a HyDE retriever).
"""
import os
from typing import List, Dict, Any, Tuple
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from src import sanitize_html
from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig
class RerankerConfig:
    """Settings consumed by the cross-encoder re-ranker.

    Both attributes are resolved once, when this class body is executed.
    """

    # Model identifier handed to ``from_pretrained``; the RERANKER_MODEL
    # environment variable overrides the built-in default.
    MODEL_NAME = os.environ.get('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')

    # Use the GPU when torch reports that CUDA is available, else the CPU.
    if torch.cuda.is_available():
        DEVICE = 'cuda'
    else:
        DEVICE = 'cpu'
class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.

    The tokenizer and model are loaded once in ``__init__`` and reused by
    every ``rerank`` call.
    """

    def __init__(self, config: "RerankerConfig"):
        """Load the tokenizer and model named by ``config``.

        Args:
            config: Object exposing ``MODEL_NAME`` and ``DEVICE`` attributes.

        Raises:
            Exception: Any error raised while loading is logged and then
                re-raised unchanged.
        """
        try:
            # Remember this instance's device so rerank() sends its inputs to
            # the same place as the model, even for a non-default config.
            self.device = config.DEVICE
            # NOTE(review): confirm that the configured checkpoint is meant to
            # be loaded through AutoModelForSequenceClassification.
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate and return top_k sorted by relevance.

        Args:
            query: The question every candidate is scored against.
            candidates: Chunk dicts; the text under ``'narration'`` is what
                gets scored (a missing or ``None`` value counts as empty text).
            top_k: Maximum number of candidates to return.

        Returns:
            The candidates ordered by descending score, cut to ``top_k``.
            An empty list when ``candidates`` is empty. If scoring fails for
            any reason, the error is logged and the first ``top_k`` candidates
            are returned in their original order (best-effort fallback).
        """
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # The tokenizer needs strings; guard against narration=None.
            texts = [c.get('narration') or '' for c in candidates]
            inputs = self.tokenizer(
                [query] * len(candidates),
                texts,
                padding=True,
                truncation=True,
                return_tensors='pt',
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # Keep exactly one score per candidate. A single-logit head gives
            # (batch, 1); for a multi-logit head the last column is used
            # (assumes the last label is the "relevant" class - TODO confirm).
            # Flattening a (batch, k) tensor instead would mix the logits of
            # different candidates together.
            logits = logits.reshape(len(candidates), -1)[:, -1]
            scores = torch.sigmoid(logits).cpu().tolist()
            # sorted() is stable, so equal scores keep retrieval order.
            ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]
class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then
    answers multiple questions without re-loading models each time.
    """

    def __init__(self, chunks: List[Dict[str, Any]]):
        """Build the retriever and re-ranker used by ``answer``.

        Args:
            chunks: Chunk dicts handed to the Retriever; prompt building
                reads each chunk's ``'narration'`` text.
        """
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Keep half of the retriever's TOP_K after re-ranking, but never
        # fewer than one chunk (a TOP_K of 1 would otherwise floor to 0 and
        # leave the prompt without any context).
        self.top_k = max(1, RetrieverConfig.TOP_K // 2)

    @staticmethod
    def _build_prompt(question: str, chunks: List[Dict[str, Any]]) -> str:
        """Assemble the LLM prompt from the question and context chunks."""
        # Chunks without narration text are skipped instead of raising
        # KeyError or adding empty bullets to the context.
        context = "\n\n".join(
            f"- {c['narration']}" for c in chunks if c.get('narration')
        )
        # Each separator sits on its own line so the context is not glued to
        # the dashes or to the instructions that follow.
        return (
            "You are a knowledgeable assistant. Use the following snippets to answer."
            "\n\nContext information is below: \n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above I want you \n"
            "to think step by step to answer the query in a crisp \n"
            "manner, incase you don't have enough information, \n"
            "just say I don't know!. \n\n"
            f"\n\nQuestion: {question} \n"
            "Answer:"
        )

    def answer(
        self, question: str
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and ask the LLM to answer ``question``.

        Args:
            question: The user's question.

        Returns:
            A ``(answer, top_chunks)`` pair: the value returned by
            ``LLMClient.generate`` and the re-ranked chunks that were used
            as context.
        """
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        prompt = self._build_prompt(question, top_chunks)
        answer = LLMClient.generate(prompt)
        return answer, top_chunks