# Token-aware text truncation, chunking, and iterative-refinement helpers.
import re
from functools import partial

import nltk
def get_len(tokenizer, text):
    """Return the token count of *text* under *tokenizer*, excluding special tokens."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return len(token_ids)
class Truncater:
    """Cut text down to at most ``max_length`` tokens of ``tokenizer``."""

    def __init__(self, tokenizer, *, max_length):
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, text):
        """Alias for :meth:`truncate`."""
        return self.truncate(text)

    def truncate(self, text):
        """Encode *text*, keep the first ``max_length`` tokens, decode back to a string."""
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_length,
        )
        return self.tokenizer.decode(token_ids)
class Refiner:
    """Drive chunk-by-chunk iterative refinement over a long text.

    Calling the instance chunks the text and returns a generator of model
    inputs, one per chunk.  After producing a summary for each yielded
    prompt, callers feed it back via :meth:`set_current_summary` so the
    next prompt carries it along.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.max_chunk_size = max_chunk_size
        self.get_len = partial(get_len, tokenizer)
        # Refinement state: latest summary and the chunks of the current text.
        self.current_summary = None
        self.chunks = []
        # Prompt fragments; configured via set_prompts().
        self.initial_prompt = ""
        self.chunk_prefix = ""
        self.summary_prefix = ""
        self.refinement_prompt = ""

    def set_prompts(self, *, initial_prompt="", chunk_prefix="", summary_prefix="", refinement_prompt=""):
        """Configure the prompt fragments used when assembling model inputs."""
        self.initial_prompt = initial_prompt
        self.chunk_prefix = chunk_prefix
        self.summary_prefix = summary_prefix
        self.refinement_prompt = refinement_prompt

    def current_prompt(self):
        """Initial prompt before any summary exists, refinement prompt afterwards."""
        return self.initial_prompt if self.current_summary is None else self.refinement_prompt

    def __call__(self, text):
        self.chunks = Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)
        return self.refine(text)

    def __len__(self):
        """Number of chunks held from the most recent chunking."""
        return len(self.chunks)

    def refine(self, text):
        """Yield one model input per chunk, folding in the latest summary.

        ``current_summary`` is re-read on every iteration on purpose, so a
        caller updating it between yields (via set_current_summary) changes
        the prompts that follow.
        """
        for piece in self.chunks:
            if self.current_summary is None:
                yield piece
                continue
            yield (self.summary_prefix + self.current_summary
                   + "\n\n" + self.chunk_prefix + piece)

    def set_current_summary(self, summary):
        """Record the model's latest summary for use in subsequent prompts."""
        self.current_summary = summary
class Chunker:
    """Split text into token-bounded chunks of whole sentences.

    ``chunk_size`` is the soft target per chunk; ``max_chunk_size`` is a
    hard upper bound — a trailing remainder no larger than it is kept as a
    single chunk.  Token lengths are measured with the given tokenizer.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size  # target chunk size
        self.max_chunk_size = max_chunk_size  # hard limit
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

    def __call__(self, text):
        return Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)

    # Decorated as staticmethods (they take no self and are called on the
    # class) so instance-bound calls like self.chunk_text(...) also work.
    @staticmethod
    def chunk_text(text, chunk_size, max_chunk_size, len_fn):
        """Return a list of chunks of *text*, each a run of whole sentences.

        Paragraph breaks are flattened to spaces before sentence-splitting
        with nltk, so chunk boundaries always fall between sentences.
        """
        paragraphs = re.split(r"\n\n|\n(?=[^\n])", text)
        text = " ".join(paragraphs)
        sentences = [s.strip() for s in nltk.sent_tokenize(text)]
        chunks = []
        Chunker._chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn)
        return chunks

    @staticmethod
    def _chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn):
        """Greedily append sentence-runs of *sentences* to *chunks*.

        Iterative (the previous one-recursion-per-chunk version could
        exhaust the call stack on long documents).

        Raises:
            ValueError: if the next sentence alone exceeds ``chunk_size``
                while the remainder is still larger than ``max_chunk_size``.
        """
        while sentences:
            remaining_text = " ".join(sentences)
            # Small enough remainder: keep it whole (it may exceed the
            # chunk_size target, but never max_chunk_size).
            if len_fn(remaining_text) <= max_chunk_size:
                chunks.append(remaining_text)
                return
            # Greedily take sentences while staying within the target size.
            index = 0
            length_so_far = 0
            while index < len(sentences):
                # len_fn may be an expensive tokenizer call — evaluate once.
                sentence_len = len_fn(sentences[index])
                if length_so_far + sentence_len > chunk_size:
                    break
                length_so_far += sentence_len
                index += 1
            if index == 0:
                raise ValueError("No chunking possible")
            chunks.append(" ".join(sentences[:index]))
            sentences = sentences[index:]