File size: 1,403 Bytes
cc3f1e1
 
 
 
 
 
 
 
 
 
 
 
 
bccb279
cc3f1e1
 
 
 
 
 
 
 
 
 
 
c56c0fd
cc3f1e1
 
 
 
 
 
 
b7b8732
cc3f1e1
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
from typing import Any
import re
from loguru import logger

from rag_demo.preprocessing.embed import EmbeddedChunk


from transformers import pipeline


class SourceAnnotator:
    """Tag every sentence of a generated response with its best-matching source chunk.

    Each sentence is scored against each retrieved chunk with an extractive
    question-answering pipeline, and the chunk with the highest ``score`` is
    cited after the sentence.
    """

    # A sentence boundary is whitespace that follows '.', '!' or '?' and
    # precedes an uppercase ASCII letter. Compiled once for all calls.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        # Extractive question answering model
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Return ``response`` with a source citation appended to each sentence.

        Args:
            response: Generated text to annotate.
            reranked_chunks: Candidate source chunks. Each must expose
                ``content``, ``chunk_id`` and ``metadata["filename"]``.

        Returns:
            The sentences of ``response``, each followed by
            ``[filename: <name>, chunk_id: <id>]`` and a single space.
            When ``reranked_chunks`` is empty there is nothing to cite and
            ``response`` is returned unchanged.
        """
        if not reranked_chunks:
            # max() over an empty candidate list would raise ValueError.
            return response

        annotated_parts = []
        for sentence in self.split_sentences(response):
            best_chunk = self._best_chunk(sentence, reranked_chunks)
            # Cite the document name without its ".pdf" extension.
            filename = best_chunk.metadata["filename"].split(".pdf")[0]
            annotated_parts.append(
                f"{sentence} [filename: {filename}, chunk_id: {best_chunk.chunk_id}] "
            )

        return "".join(annotated_parts)

    def _best_chunk(self, sentence: str, chunks: list[EmbeddedChunk]) -> EmbeddedChunk:
        """Return the chunk whose QA score for ``sentence`` is highest (first wins on ties)."""

        def qa_score(chunk: EmbeddedChunk) -> float:
            # The sentence is used as the "question", the chunk text as the "context".
            result = self.source_annotator(question=sentence, context=chunk.content)
            return result["score"]

        # Could also use a score cut-off instead of max()
        return max(chunks, key=qa_score)

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        A split happens only where sentence-ending punctuation is followed by
        whitespace and an uppercase ASCII letter, so ``"Two? three."`` stays
        together as one sentence.
        """
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]