import re
from functools import partial

import nltk


def get_len(tokenizer, text):
    """Return the token count of *text* under *tokenizer* (no special tokens)."""
    return len(tokenizer.encode(text, add_special_tokens=False))


class Truncater:
    """Truncate text to at most ``max_length`` tokens of *tokenizer*.

    Instances are callable: ``Truncater(tok, max_length=512)(text)``.
    """

    def __init__(self, tokenizer, *, max_length):
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, text):
        return self.truncate(text)

    def truncate(self, text):
        """Encode with hard truncation at ``max_length`` and decode back to a string."""
        input_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_length,
        )
        return self.tokenizer.decode(input_ids)


class Refiner:
    """Iteratively summarize long text chunk-by-chunk ("refine" strategy).

    Usage: call the instance with the full text to obtain a generator of
    model inputs; after each step, feed the produced summary back via
    :meth:`set_current_summary` so the next yielded input carries it.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size
        self.max_chunk_size = max_chunk_size
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)
        # Summary accumulated so far; None until the caller sets the first one.
        self.current_summary = None
        self.chunks = []
        self.initial_prompt = ""
        self.chunk_prefix = ""
        self.summary_prefix = ""
        self.refinement_prompt = ""

    def set_prompts(self, *, initial_prompt="", chunk_prefix="",
                    summary_prefix="", refinement_prompt=""):
        """Configure the prompt fragments used when assembling each step's input."""
        self.initial_prompt = initial_prompt
        self.chunk_prefix = chunk_prefix
        self.summary_prefix = summary_prefix
        self.refinement_prompt = refinement_prompt

    @property
    def current_prompt(self):
        """Prompt for the current step: initial before any summary, refinement after."""
        if self.current_summary is None:
            return self.initial_prompt
        return self.refinement_prompt

    def __call__(self, text):
        self.chunks = Chunker.chunk_text(
            text, self.chunk_size, self.max_chunk_size, self.get_len
        )
        return self.refine(text)

    def __len__(self):
        return len(self.chunks)

    def refine(self, text):
        """Yield one model input per chunk.

        NOTE(review): *text* is unused here — chunking already happened in
        ``__call__``; the parameter is kept for backward compatibility.
        """
        for chunk in self.chunks:
            if self.current_summary is None:
                # First chunk: there is no summary to carry over yet.
                yield chunk
            else:
                summary = self.summary_prefix + self.current_summary
                chunk = self.chunk_prefix + chunk
                yield summary + "\n\n" + chunk

    def set_current_summary(self, summary):
        """Record the summary produced for the last yielded chunk."""
        self.current_summary = summary


class Chunker:
    """Split text into chunks of roughly ``chunk_size`` tokens.

    ``chunk_size`` is the greedy packing target; ``max_chunk_size`` is the
    hard upper bound under which a trailing remainder may be emitted whole.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size  # target chunk size
        self.max_chunk_size = max_chunk_size  # hard limit
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

    def __call__(self, text):
        return Chunker.chunk_text(
            text, self.chunk_size, self.max_chunk_size, self.get_len
        )

    @staticmethod
    def chunk_text(text, chunk_size, max_chunk_size, len_fn):
        """Sentence-tokenize *text* and pack sentences into token-bounded chunks."""
        # Collapse paragraph/line breaks so sentence tokenization sees one stream.
        paragraphs = re.split("\n\n|\n(?=[^\n])", text)
        text = " ".join(paragraphs)
        sentences = nltk.sent_tokenize(text)
        sentences = [s.strip() for s in sentences]
        chunks = []
        Chunker._chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn)
        return chunks

    @staticmethod
    def _chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn):
        """Greedily pack *sentences* into *chunks*, appending in place.

        Iterative rewrite of the original tail recursion, which used one
        stack frame per emitted chunk and could raise ``RecursionError`` on
        long documents. Each sentence is now encoded once per pass instead
        of twice (the original called ``len_fn`` in both the loop condition
        and the accumulator).

        Raises:
            ValueError: if a single sentence exceeds ``chunk_size`` while
                the remaining text exceeds ``max_chunk_size`` (no valid split).
        """
        while sentences:
            remaining_text = " ".join(sentences)
            # If everything left fits under the hard limit, emit it whole.
            if len_fn(remaining_text) <= max_chunk_size:
                chunks.append(remaining_text)
                return
            index = 0
            length_so_far = 0
            while index < len(sentences):
                sentence_len = len_fn(sentences[index])  # one encode per sentence
                if length_so_far + sentence_len > chunk_size:
                    break
                length_so_far += sentence_len
                index += 1
            if index == 0:
                raise ValueError("No chunking possible")
            chunks.append(" ".join(sentences[:index]))
            sentences = sentences[index:]