Adrien
fix
b7b8732
import os
from typing import Any
import re
from loguru import logger
from rag_demo.preprocessing.embed import EmbeddedChunk
from transformers import pipeline
class SourceAnnotator:
    """Cite, for each sentence of a generated response, its best-supporting chunk.

    Every sentence is posed as the *question* to an extractive
    question-answering model, with each candidate chunk's text as the
    *context*. The chunk whose QA result has the highest ``score`` is cited
    after the sentence.
    """

    # Sentence boundary: whitespace preceded by '.', '!' or '?' and followed by
    # an uppercase ASCII letter. Compiled once rather than on every call.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        # Extractive question answering model, loaded when the annotator is built.
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list["EmbeddedChunk"]) -> str:
        """Append a source citation to every sentence of ``response``.

        Args:
            response: Generated answer text to annotate.
            reranked_chunks: Candidate source chunks. Each must expose
                ``content`` (text), ``metadata["filename"]`` and ``chunk_id``.

        Returns:
            The sentences of ``response`` joined by single spaces, each
            followed by ``[filename: <filename up to ".pdf">, chunk_id: <id>]``.
            If ``reranked_chunks`` is empty, the sentences are returned
            without citations.
        """
        sentences = self.split_sentences(response)
        if not reranked_chunks:
            # Nothing to cite; max() over zero candidates would raise ValueError.
            return " ".join(sentences)

        annotated = []
        for sentence in sentences:
            best = self._best_source(sentence, reranked_chunks)
            annotated.append(
                f"{sentence} [filename: {best['filename']}, chunk_id: {best['chunk_id']}]"
            )
        return " ".join(annotated)

    def _best_source(self, sentence, chunks):
        """Return the highest-scoring QA result for ``sentence`` over ``chunks``.

        The returned dict is the pipeline's result for the winning chunk,
        augmented with ``filename`` (text before the first ".pdf") and
        ``chunk_id``. Ties keep the earliest chunk, matching ``max()``.
        """
        scores = []
        for chunk in chunks:
            score = self.source_annotator(question=sentence, context=chunk.content)
            score["filename"] = chunk.metadata["filename"].split(".pdf")[0]
            score["chunk_id"] = chunk.chunk_id
            scores.append(score)
        # Could also use a score cut-off instead of max()
        return max(scores, key=lambda x: x["score"])

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        A boundary is whitespace after '.', '!' or '?' that precedes an
        uppercase ASCII letter, so ``"Ok? yes"`` stays one sentence.
        """
        return [s.strip() for s in self._SENTENCE_BOUNDARY.split(text) if s.strip()]