import logging
import uuid
from typing import Any, List, Optional

import nltk
from nltk import sent_tokenize
from pydantic import BaseModel

from obsei.payload import TextPayload
from obsei.preprocessor.base_preprocessor import (
    BaseTextPreprocessor,
    BaseTextProcessorConfig,
)

logger = logging.getLogger(__name__)


class TextSplitterPayload(BaseModel):
    """Metadata attached to each chunk produced by TextSplitter."""

    phrase: str
    chunk_id: int
    chunk_length: int
    document_id: str
    total_chunks: Optional[int] = None


class TextSplitterConfig(BaseTextProcessorConfig):
    max_split_length: int = 512
    split_stride: int = 0  # number of overlapping characters between consecutive chunks
    document_id_key: Optional[str] = None  # meta key holding the document id, if any
    enable_sentence_split: bool = False
    honor_paragraph_boundary: bool = False
    paragraph_marker: str = "\n\n"
    sentence_tokenizer: str = "tokenizers/punkt/PY3/english.pickle"

    def __init__(self, **data: Any):
        super().__init__(**data)

        # Sentence splitting relies on NLTK's punkt tokenizer models.
        if self.enable_sentence_split:
            nltk.download("punkt")


class TextSplitter(BaseTextPreprocessor):
    def preprocess_input(  # type: ignore[override]
        self, input_list: List[TextPayload], config: TextSplitterConfig, **kwargs: Any
    ) -> List[TextPayload]:
        text_splits: List[TextPayload] = []
        for input_data in input_list:
            # Reuse the document id from meta when configured, otherwise mint a new one.
            if (
                config.document_id_key
                and input_data.meta
                and config.document_id_key in input_data.meta
            ):
                document_id = str(input_data.meta.get(config.document_id_key))
            else:
                document_id = uuid.uuid4().hex

            # Optionally respect paragraph boundaries so a chunk never spans two paragraphs.
            if config.honor_paragraph_boundary:
                paragraphs = input_data.processed_text.split(config.paragraph_marker)
            else:
                paragraphs = [input_data.processed_text]

            # Optionally split each paragraph into sentences before chunking.
            atomic_texts: List[str] = []
            for paragraph in paragraphs:
                if config.enable_sentence_split:
                    atomic_texts.extend(sent_tokenize(paragraph))
                else:
                    atomic_texts.append(paragraph)

            split_id = 0
            document_splits: List[TextSplitterPayload] = []
            for text in atomic_texts:
                text_length = len(text)
                if text_length == 0:
                    continue

                start_idx = 0
                while start_idx < text_length:
                    # With a positive stride, step back to overlap the previous chunk,
                    # snapping to whitespace so words are not cut in half.
                    if config.split_stride > 0 and start_idx > 0:
                        start_idx = (
                            self._valid_index(text, start_idx - config.split_stride) + 1
                        )
                    end_idx = self._valid_index(
                        text,
                        min(start_idx + config.max_split_length, text_length),
                    )
                    phrase = text[start_idx:end_idx]
                    document_splits.append(
                        TextSplitterPayload(
                            phrase=phrase,
                            chunk_id=split_id,
                            chunk_length=len(phrase),
                            document_id=document_id,
                        )
                    )
                    start_idx = end_idx + 1
                    split_id += 1

            # Back-fill the total chunk count and emit one TextPayload per chunk.
            total_splits = len(document_splits)
            for split in document_splits:
                split.total_chunks = total_splits
                payload = TextPayload(
                    processed_text=split.phrase,
                    source_name=input_data.source_name,
                    segmented_data=input_data.segmented_data,
                    meta={**input_data.meta, "splitter": split}
                    if input_data.meta
                    else {"splitter": split},
                )
                text_splits.append(payload)

        return text_splits

    @staticmethod
    def _valid_index(document: str, idx: int) -> int:
        """Walk idx backwards to the nearest whitespace so splits land on word boundaries."""
        if idx <= 0:
            return 0
        if idx >= len(document):
            return len(document)

        new_idx = idx
        while new_idx > 0:
            if document[new_idx] in [" ", "\n", "\t"]:
                break
            new_idx -= 1
        return new_idx
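

# A minimal usage sketch (not part of the original module), assuming obsei and its
# dependencies are installed and that TextSplitter can be constructed without
# arguments; the sample text, the "doc_id" meta key, and the config values below
# are illustrative only.
if __name__ == "__main__":
    sample = TextPayload(
        processed_text=(
            "First paragraph with a few sentences. " * 10
            + "\n\n"
            + "Second paragraph, also long enough to be chunked. " * 10
        ),
        source_name="example",
        meta={"doc_id": "demo-doc-1"},
    )

    splitter_config = TextSplitterConfig(
        max_split_length=128,
        split_stride=16,  # overlap between consecutive chunks
        document_id_key="doc_id",  # reuse the id already stored in meta
        honor_paragraph_boundary=True,
    )

    for chunk in TextSplitter().preprocess_input([sample], config=splitter_config):
        info = chunk.meta["splitter"]
        print(info.chunk_id, info.chunk_length, info.total_chunks, repr(chunk.processed_text[:40]))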