import os
import re
from typing import Any

from loguru import logger
from transformers import pipeline

from rag_demo.preprocessing.embed import EmbeddedChunk


class SourceAnnotator:
    """Attach a source citation to every sentence of a generated response.

    Each sentence is scored against each retrieved chunk with an extractive
    question-answering model; the chunk with the highest score is cited
    after the sentence.
    """

    # A sentence boundary is whitespace that follows '.', '!' or '?' and is
    # followed by an uppercase ASCII letter.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        # Extractive question answering model
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Return ``response`` with a source annotation after each sentence.

        Args:
            response: The generated answer text to annotate.
            reranked_chunks: Candidate source chunks. Each chunk is read for
                its ``content``, ``chunk_id`` and ``metadata["filename"]``.

        Returns:
            The sentences of ``response``, each followed by
            ``[filename: <name>, chunk_id: <id>]`` and a single space. When
            ``reranked_chunks`` is empty there is nothing to cite, so
            ``response`` is returned unchanged.
        """
        if not reranked_chunks:
            # max() over an empty candidate list would raise ValueError.
            logger.warning("No chunks provided; returning response without annotations.")
            return response

        annotated_parts = []
        for sentence in self.split_sentences(response):
            best_chunk = self._best_chunk(sentence, reranked_chunks)
            filename = best_chunk.metadata["filename"].split(".pdf")[0]
            annotated_parts.append(
                f"{sentence} [filename: {filename}, chunk_id: {best_chunk.chunk_id}] "
            )
        return "".join(annotated_parts)

    def _best_chunk(self, sentence: str, chunks: list[EmbeddedChunk]) -> EmbeddedChunk:
        """Return the chunk with the highest QA score for ``sentence``."""

        def qa_score(chunk: EmbeddedChunk) -> float:
            # The sentence is used as the "question" and the chunk text as
            # the "context"; keywords are the documented calling convention.
            result = self.source_annotator(question=sentence, context=chunk.content)
            return result["score"]

        # Could also use a score cut-off instead of max()
        return max(chunks, key=qa_score)

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        Args:
            text: The text to split.

        Returns:
            The sentences in their original order, with surrounding
            whitespace removed and empty pieces dropped.
        """
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]