File size: 1,403 Bytes
cc3f1e1 bccb279 cc3f1e1 c56c0fd cc3f1e1 b7b8732 cc3f1e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import os
from typing import Any
import re
from loguru import logger
from rag_demo.preprocessing.embed import EmbeddedChunk
from transformers import pipeline
class SourceAnnotator:
    """Cite, for each sentence of a generated response, the chunk that best supports it.

    An extractive question-answering pipeline is used as a relevance scorer:
    each response sentence is passed as the ``question`` and each candidate
    chunk's text as the ``context``; the chunk whose result has the highest
    ``"score"`` is cited after the sentence.
    """

    # Sentence boundary: whitespace that follows ".", "!" or "?" and precedes
    # an uppercase ASCII letter. Compiled once since it is reused on every call.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        # Extractive question answering model, used here only for its scores.
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list["EmbeddedChunk"]) -> str:
        """Append a ``[filename: ..., chunk_id: ...]`` citation to every sentence.

        Args:
            response: Generated text to annotate.
            reranked_chunks: Candidate source chunks; each must expose
                ``content``, ``chunk_id`` and ``metadata["filename"]``.

        Returns:
            The sentences of ``response`` (as produced by ``split_sentences``)
            joined by single spaces, each followed by the citation of its
            highest-scoring chunk. When ``reranked_chunks`` is empty there is
            nothing to cite, so the sentences are returned without citations.
        """
        annotated = []
        for sentence in self.split_sentences(response):
            source = self._best_source(sentence, reranked_chunks)
            if source is None:
                annotated.append(sentence)
                continue
            # Keep only the text before the first ".pdf" for display.
            filename = source.metadata["filename"].split(".pdf")[0]
            annotated.append(
                f"{sentence} [filename: {filename}, chunk_id: {source.chunk_id}]"
            )
        return " ".join(annotated)

    def _best_source(
        self, sentence: str, chunks: list["EmbeddedChunk"]
    ) -> "EmbeddedChunk | None":
        """Return the chunk scoring highest for ``sentence``; None if there are no chunks.

        Ties go to the earliest chunk in ``chunks``.
        """
        if not chunks:
            return None
        # NOTE: a score cut-off could be used here instead of always citing the max.
        return max(
            chunks,
            key=lambda chunk: self.source_annotator(
                question=sentence, context=chunk.content
            )["score"],
        )

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        Heuristic: a break happens only where ".", "!" or "?" is followed by
        whitespace and then an uppercase ASCII letter, so punctuation followed
        by a lowercase word does not start a new sentence.
        """
        return [
            part.strip()
            for part in self._SENTENCE_BOUNDARY.split(text)
            if part.strip()
        ]
|