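"""Text splitter preprocessor.

Splits long TextPayload documents into smaller chunks, optionally honoring
paragraph boundaries, splitting paragraphs into sentences with NLTK, and
overlapping consecutive chunks by split_stride characters.
"""
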
import logging
from typing import List, Optional, Any
import uuid

import nltk
from nltk import sent_tokenize
from pydantic import BaseModel

from obsei.payload import TextPayload
from obsei.preprocessor.base_preprocessor import (
    BaseTextPreprocessor,
    BaseTextProcessorConfig,
)

logger = logging.getLogger(__name__)


class TextSplitterPayload(BaseModel):
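    """A single chunk of a split document and its position within that document."""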
    phrase: str
    chunk_id: int
    chunk_length: int
    document_id: str
    total_chunks: Optional[int] = None


class TextSplitterConfig(BaseTextProcessorConfig):
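    """Configuration for TextSplitter.

    max_split_length: maximum chunk size in characters.
    split_stride: overlap (in characters) between consecutive chunks; expected
        to be smaller than max_split_length.
    document_id_key: meta key whose value is reused as the document id;
        a random UUID is generated when it is absent.
    enable_sentence_split: split paragraphs into sentences (NLTK punkt) before
        chunking.
    honor_paragraph_boundary: split the text on paragraph_marker first so no
        chunk crosses a paragraph boundary.
    """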
    max_split_length: int = 512
    split_stride: int = 0  # overlap length
    document_id_key: Optional[str] = None  # document_id in meta
    enable_sentence_split: bool = False
    honor_paragraph_boundary: bool = False
    paragraph_marker: str = "\n\n"
    sentence_tokenizer: str = "tokenizers/punkt/PY3/english.pickle"

    def __init__(self, **data: Any):
        super().__init__(**data)

        if self.enable_sentence_split:
            nltk.download("punkt")


class TextSplitter(BaseTextPreprocessor):
    def preprocess_input(  # type: ignore[override]
        self, input_list: List[TextPayload], config: TextSplitterConfig, **kwargs: Any
    ) -> List[TextPayload]:
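        """Split every input payload into chunks of at most max_split_length characters.

        Each returned TextPayload holds one chunk as processed_text and carries a
        TextSplitterPayload under the "splitter" meta key describing the chunk's
        id, length, parent document id, and the document's total chunk count.
        """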
        text_splits: List[TextPayload] = []

        for input_data in input_list:
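            # Reuse a caller-supplied document id from meta when configured;
            # otherwise generate a fresh UUID so all chunks of this document
            # share the same document_id.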
            if (
                config.document_id_key
                and input_data.meta
                and config.document_id_key in input_data.meta
            ):
                document_id = str(input_data.meta.get(config.document_id_key))
            else:
                document_id = uuid.uuid4().hex

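            # Optionally split on the paragraph marker first so chunks never
            # cross a paragraph boundary.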
            if config.honor_paragraph_boundary:
                paragraphs = input_data.processed_text.split(config.paragraph_marker)
            else:
                paragraphs = [input_data.processed_text]

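            # Build the atomic units to chunk: sentences when sentence splitting
            # is enabled, otherwise whole paragraphs.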
            atomic_texts: List[str] = []
            for paragraph in paragraphs:
                if config.enable_sentence_split:
                    atomic_texts.extend(sent_tokenize(paragraph))
                else:
                    atomic_texts.append(paragraph)

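            # Slide a window of at most max_split_length characters over each
            # atomic unit, snapping cut points to whitespace via _valid_index.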
            split_id = 0
            document_splits: List[TextSplitterPayload] = []
            for text in atomic_texts:
                text_length = len(text)
                if text_length == 0:
                    continue

                start_idx = 0
                while start_idx < text_length:
                    if config.split_stride > 0 and start_idx > 0:
                        # Overlap with the previous chunk: step back by
                        # split_stride characters, snapping to just after the
                        # nearest preceding whitespace, but never further back
                        # than split_stride itself (a long whitespace-free token
                        # would otherwise stall the window).
                        start_idx = max(
                            self._valid_index(text, start_idx - config.split_stride) + 1,
                            start_idx - config.split_stride,
                        )
                    end_idx = self._valid_index(
                        text,
                        min(start_idx + config.max_split_length, text_length),
                    )
                    # Fallback: a whitespace-free token longer than
                    # max_split_length makes _valid_index walk back to (or past)
                    # start_idx; hard-cut instead so the loop always advances
                    # and never emits empty phrases.
                    if end_idx <= start_idx:
                        end_idx = min(start_idx + config.max_split_length, text_length)
                        next_start_idx = end_idx
                    else:
                        # end_idx points at whitespace (or the end of the text),
                        # which the next chunk can safely skip.
                        next_start_idx = end_idx + 1

                    phrase = text[start_idx:end_idx]
                    document_splits.append(
                        TextSplitterPayload(
                            phrase=phrase,
                            chunk_id=split_id,
                            chunk_length=len(phrase),
                            document_id=document_id,
                        )
                    )
                    start_idx = next_start_idx
                    split_id += 1

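            # The total number of chunks is known only now; back-fill it and wrap
            # each chunk into a TextPayload carrying the splitter metadata.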
            total_splits = len(document_splits)
            for split in document_splits:
                split.total_chunks = total_splits
                payload = TextPayload(
                    processed_text=split.phrase,
                    source_name=input_data.source_name,
                    segmented_data=input_data.segmented_data,
                    meta={**input_data.meta, "splitter": split}
                    if input_data.meta
                    else {"splitter": split},
                )
                text_splits.append(payload)

        return text_splits

    @staticmethod
    def _valid_index(document: str, idx: int) -> int:
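        """Return the index of the nearest whitespace at or before idx.

        Indices outside the document are clamped to [0, len(document)]; 0 is
        returned when no whitespace precedes idx. Snapping to whitespace keeps
        chunk boundaries from cutting words in half.
        """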
        if idx <= 0:
            return 0
        if idx >= len(document):
            return len(document)
        new_idx = idx
        while new_idx > 0:
            if document[new_idx] in [" ", "\n", "\t"]:
                break
            new_idx -= 1
        return new_idx
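

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only and not part of the library API.
# It assumes TextPayload's optional fields (meta, segmented_data) default to
# None and that TextSplitter can be constructed without arguments; adjust to
# your obsei version if that differs.
if __name__ == "__main__":
    sample = TextPayload(
        processed_text=(
            "Obsei observes unstructured data from a source, analyses it, "
            "and informs a downstream system. This splitter breaks long "
            "documents into overlapping chunks before analysis."
        ),
        source_name="example",
    )

    # Small, illustrative parameter values: 64-character chunks with an
    # 8-character overlap between consecutive chunks.
    splitter_config = TextSplitterConfig(max_split_length=64, split_stride=8)

    for chunk in TextSplitter().preprocess_input([sample], config=splitter_config):
        splitter_meta = chunk.meta["splitter"]
        print(
            f"chunk {splitter_meta.chunk_id + 1}/{splitter_meta.total_chunks} "
            f"({splitter_meta.chunk_length} chars): {chunk.processed_text!r}"
        )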