import logging
from typing import List, Optional, Any
import uuid
import nltk
from nltk import sent_tokenize
from pydantic import BaseModel
from obsei.payload import TextPayload
from obsei.preprocessor.base_preprocessor import (
    BaseTextPreprocessor,
    BaseTextProcessorConfig,
)

logger = logging.getLogger(__name__)


class TextSplitterPayload(BaseModel):
    phrase: str
    chunk_id: int
    chunk_length: int
    document_id: str
    total_chunks: Optional[int] = None


class TextSplitterConfig(BaseTextProcessorConfig):
    max_split_length: int = 512
    split_stride: int = 0  # overlap length between consecutive chunks
    document_id_key: Optional[str] = None  # meta key that carries the document id
    enable_sentence_split: bool = False
    honor_paragraph_boundary: bool = False
    paragraph_marker: str = '\n\n'
    sentence_tokenizer: str = 'tokenizers/punkt/PY3/english.pickle'

    def __init__(self, **data: Any):
        super().__init__(**data)
        if self.enable_sentence_split:
            # Sentence splitting relies on NLTK's punkt tokenizer models
            nltk.download('punkt')


class TextSplitter(BaseTextPreprocessor):
    def preprocess_input(  # type: ignore[override]
        self, input_list: List[TextPayload], config: TextSplitterConfig, **kwargs: Any
    ) -> List[TextPayload]:
        text_splits: List[TextPayload] = []
        for input_data in input_list:
            # Reuse a document id from the payload meta if configured, otherwise generate one
            if (
                config.document_id_key
                and input_data.meta
                and config.document_id_key in input_data.meta
            ):
                document_id = str(input_data.meta.get(config.document_id_key))
            else:
                document_id = uuid.uuid4().hex

            if config.honor_paragraph_boundary:
                paragraphs = input_data.processed_text.split(config.paragraph_marker)
            else:
                paragraphs = [input_data.processed_text]

            # Optionally split each paragraph into sentences before chunking
            atomic_texts: List[str] = []
            for paragraph in paragraphs:
                if config.enable_sentence_split:
                    atomic_texts.extend(sent_tokenize(paragraph))
                else:
                    atomic_texts.append(paragraph)

            split_id = 0
            document_splits: List[TextSplitterPayload] = []
            for text in atomic_texts:
                text_length = len(text)
                if text_length == 0:
                    continue
                start_idx = 0
                while start_idx < text_length:
                    # Step back by the stride so consecutive chunks overlap
                    if config.split_stride > 0 and start_idx > 0:
                        start_idx = (
                            self._valid_index(text, start_idx - config.split_stride) + 1
                        )
                    # Cut at the nearest whitespace before the length limit
                    end_idx = self._valid_index(
                        text,
                        min(start_idx + config.max_split_length, text_length),
                    )
                    phrase = text[start_idx:end_idx]
                    document_splits.append(
                        TextSplitterPayload(
                            phrase=phrase,
                            chunk_id=split_id,
                            chunk_length=len(phrase),
                            document_id=document_id,
                        )
                    )
                    start_idx = end_idx + 1
                    split_id += 1

            # Backfill the chunk count and wrap each chunk in a TextPayload
            total_splits = len(document_splits)
            for split in document_splits:
                split.total_chunks = total_splits
                payload = TextPayload(
                    processed_text=split.phrase,
                    source_name=input_data.source_name,
                    segmented_data=input_data.segmented_data,
                    meta={**input_data.meta, "splitter": split}
                    if input_data.meta
                    else {"splitter": split},
                )
                text_splits.append(payload)

        return text_splits

    @staticmethod
    def _valid_index(document: str, idx: int) -> int:
        # Walk backwards from idx until a whitespace character is found,
        # so chunks end on token boundaries rather than mid-word.
        if idx <= 0:
            return 0
        if idx >= len(document):
            return len(document)
        new_idx = idx
        while new_idx > 0:
            if document[new_idx] in [" ", "\n", "\t"]:
                break
            new_idx -= 1
        return new_idx
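

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): builds a config,
    # runs the splitter over a single payload, and prints the resulting chunks.
    # Assumes TextPayload only needs processed_text/source_name here and that
    # TextSplitter() can be instantiated without arguments, as the classes above suggest.
    sample = TextPayload(
        processed_text=(
            "First paragraph of a long document.\n\n"
            "Second paragraph with more text."
        ),
        source_name="sample",
    )
    splitter_config = TextSplitterConfig(
        max_split_length=64,
        split_stride=8,
        honor_paragraph_boundary=True,
    )
    splitter = TextSplitter()
    for chunk in splitter.preprocess_input([sample], config=splitter_config):
        info = chunk.meta["splitter"]
        print(info.chunk_id, info.chunk_length, repr(chunk.processed_text))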