"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.

This module contains:

- Retriever: hybrid BM25 + dense retrieval over parsed chunks
- Reranker: cross-encoder re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation

Each component is modular and can be swapped or extended (e.g., with a HyDE retriever).
"""
|
import os |
|
from typing import List, Dict, Any, Tuple |
|
|
|
from sentence_transformers import SentenceTransformer |
|
from rank_bm25 import BM25Okapi |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
from src import sanitize_html |
|
from src.utils import LLMClient, logger |
|
from src.retriever import Retriever, RetrieverConfig |
|
|
|
|
|


class RerankerConfig:
    # bge-reranker-v2-m3 ships a sequence-classification head, which is what the
    # Reranker below loads; the v2-gemma variant is a causal-LM reranker and is
    # not compatible with AutoModelForSequenceClassification.
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.
    """

    def __init__(self, config: Type[RerankerConfig]):
        self.device = config.DEVICE
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only; disable dropout
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against the query and return the top_k sorted by relevance."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # Cross-encoder scoring: each (query, candidate) pair is encoded jointly.
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**inputs)

            logits = out.logits
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(-1)

            # Map raw relevance logits to (0, 1) scores and pair each with its candidate.
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]

            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            # Fall back to the retriever's original ordering rather than failing the query.
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]


class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then
    answers multiple questions without re-loading models each time.
    """

    def __init__(self, chunks: List[Dict[str, Any]]):
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Keep half of the retrieved candidates after re-ranking.
        self.top_k = RetrieverConfig.TOP_K // 2

    def answer(self, question: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and generate; returns (answer, supporting chunks)."""
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer.\n"
            "\nContext information is below:\n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above, think step by step and answer\n"
            "the query in a crisp manner. In case you don't have enough\n"
            "information, just say \"I don't know!\".\n"
            f"\nQuestion: {question}\n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks
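

# Minimal usage sketch (assumes chunks shaped like {'narration': ...}, as the code
# above expects, and that src.retriever / src.utils resolve in this environment;
# the sample data is purely illustrative).
if __name__ == '__main__':
    sample_chunks = [
        {'narration': 'The invoice was paid on 2024-03-01 via bank transfer.'},
        {'narration': 'Refunds are processed within 5 business days.'},
    ]
    generator = AnswerGenerator(sample_chunks)  # loads retriever + reranker once
    answer, sources = generator.answer('How long do refunds take?')
    print(answer)
    for chunk in sources:
        print('-', chunk.get('narration', ''))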