|
import os |
|
from typing import Any |
|
import re |
|
from loguru import logger |
|
|
|
from rag_demo.preprocessing.embed import EmbeddedChunk |
|
|
|
|
|
from transformers import pipeline |
|
|
|
|
|
class SourceAnnotator:
    """Append a source citation to every sentence of a generated response.

    Each sentence is scored against every retrieved chunk with an extractive
    question-answering pipeline (sentence as the question, chunk text as the
    context). The chunk with the highest ``"score"`` is cited for that
    sentence.
    """

    # A sentence boundary is whitespace that follows '.', '!' or '?' and is
    # followed by an uppercase ASCII letter. Compiled once for all calls.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        """Load the question-answering pipeline used to score chunks."""
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Return ``response`` with a citation appended to each sentence.

        Args:
            response: Generated text to annotate.
            reranked_chunks: Candidate source chunks. Each must expose
                ``content``, ``metadata["filename"]`` and ``chunk_id``.

        Returns:
            The sentences of ``response``, each followed by
            ``[filename: <name>, chunk_id: <id>]`` and a single space. When
            ``reranked_chunks`` is empty there is nothing to cite, so
            ``response`` is returned unchanged.
        """
        if not reranked_chunks:
            # max() over an empty sequence would raise ValueError.
            return response

        annotated_parts = []
        for sentence in self.split_sentences(response):
            best_chunk = self._best_chunk(sentence, reranked_chunks)
            # Cite the document name without its ".pdf" extension.
            filename = best_chunk.metadata["filename"].split(".pdf")[0]
            annotated_parts.append(
                f"{sentence} [filename: {filename}, chunk_id: {best_chunk.chunk_id}] "
            )
        return "".join(annotated_parts)

    def _best_chunk(
        self, sentence: str, reranked_chunks: list[EmbeddedChunk]
    ) -> EmbeddedChunk:
        """Return the chunk with the highest QA score for ``sentence``.

        Ties are resolved in favour of the earliest chunk in the list.
        """

        def qa_score(chunk):
            result = self.source_annotator(question=sentence, context=chunk.content)
            return result["score"]

        return max(reranked_chunks, key=qa_score)

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        A split happens only at whitespace that follows '.', '!' or '?' and
        precedes an uppercase ASCII letter, so a sentence starting with a
        lowercase letter or a digit stays attached to the previous one.
        """
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]
|
|