import re
from functools import partial

import nltk


def get_len(tokenizer, text):
    # Token count of `text` under the given tokenizer (no special tokens).
    return len(tokenizer.encode(text, add_special_tokens=False))


class Truncater:
    """Truncates a text to at most `max_length` tokens."""

    def __init__(self, tokenizer, *, max_length):
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, text):
        return self.truncate(text)

    def truncate(self, text):
        # Encode with truncation at `max_length` tokens, then decode back to text.
        input_ids = self.tokenizer.encode(
            text, add_special_tokens=False, truncation=True, max_length=self.max_length
        )
        return self.tokenizer.decode(input_ids)
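

# A minimal usage sketch for Truncater. The concrete tokenizer is an
# assumption; any Hugging Face `transformers` tokenizer with encode/decode
# would work the same way.
def _demo_truncater():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    truncater = Truncater(tokenizer, max_length=8)
    # Prints only the first 8 tokens of the input, decoded back to text.
    print(truncater("This sentence is long enough that it will be cut off."))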


class Refiner:
    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size
        self.max_chunk_size = max_chunk_size
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)
        self.current_summary = None
        self.chunks = []
        self.initial_prompt = ""
        self.chunk_prefix = ""
        self.summary_prefix = ""
        self.refinement_prompt = ""

    def set_prompts(self, *, initial_prompt="", chunk_prefix="", summary_prefix="", refinement_prompt=""):
        self.initial_prompt = initial_prompt
        self.chunk_prefix = chunk_prefix
        self.summary_prefix = summary_prefix
        self.refinement_prompt = refinement_prompt

    @property
    def current_prompt(self):
        # The initial prompt applies until the first summary is set; every
        # later step uses the refinement prompt.
        if self.current_summary is None:
            return self.initial_prompt
        return self.refinement_prompt

    def __call__(self, text):
        self.chunks = Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)
        self.current_summary = None  # reset state so the refiner can be reused
        return self.refine()

    def __len__(self):
        return len(self.chunks)

    def refine(self):
        # Lazily yields model inputs; `current_summary` is read per iteration,
        # so the caller should call `set_current_summary` after each step.
        for chunk in self.chunks:
            if self.current_summary is None:
                yield chunk
            else:
                summary = self.summary_prefix + self.current_summary
                chunk = self.chunk_prefix + chunk
                yield summary + "\n\n" + chunk

    def set_current_summary(self, summary):
        self.current_summary = summary
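

# A minimal sketch of the refinement loop driven by Refiner. `summarize` is a
# hypothetical stand-in for an actual model call; it would receive the prompt
# from `refiner.current_prompt` together with the yielded text.
def _demo_refiner(tokenizer, text, summarize):
    refiner = Refiner(tokenizer, chunk_size=512, max_chunk_size=768)
    refiner.set_prompts(
        initial_prompt="Summarize the following text.",
        chunk_prefix="Text: ",
        summary_prefix="Current summary: ",
        refinement_prompt="Refine the current summary with the additional text.",
    )
    for model_input in refiner(text):
        summary = summarize(refiner.current_prompt, model_input)
        refiner.set_current_summary(summary)
    return refiner.current_summary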


class Chunker:
    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size  # target chunk size (in tokens)
        self.max_chunk_size = max_chunk_size  # hard limit (in tokens)
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

    def __call__(self, text):
        return Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)

    @staticmethod
    def chunk_text(text, chunk_size, max_chunk_size, len_fn):
        # Flatten paragraph breaks, then split into sentences.
        # Requires the NLTK sentence tokenizer model: nltk.download("punkt")
        paragraphs = re.split(r"\n\n|\n(?=[^\n])", text)
        text = " ".join(paragraphs)
        sentences = [s.strip() for s in nltk.sent_tokenize(text)]
        chunks = []
        Chunker._chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn)
        return chunks

    @staticmethod
    def _chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn):
        if not sentences:
            return
        remaining_text = " ".join(sentences)
        # If everything left fits under the hard limit, keep it as one chunk
        # rather than splitting off a tiny trailing chunk.
        if len_fn(remaining_text) <= max_chunk_size:
            chunks.append(remaining_text)
            return
        # Greedily take sentences until the target chunk size is reached.
        index = 0
        length_so_far = 0
        while index < len(sentences) and length_so_far + len_fn(sentences[index]) <= chunk_size:
            length_so_far += len_fn(sentences[index])
            index += 1
        if index == 0:
            # A single sentence already exceeds the target size; emit it on its
            # own if it still fits under the hard limit, otherwise give up.
            if len_fn(sentences[0]) > max_chunk_size:
                raise ValueError("No chunking possible: a single sentence exceeds max_chunk_size")
            index = 1
        chunks.append(" ".join(sentences[:index]))
        # Tail-recurse on the remaining sentences (one level per chunk).
        Chunker._chunk_text(sentences[index:], chunks, chunk_size, max_chunk_size, len_fn)
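

# A minimal usage sketch for Chunker, again assuming a Hugging Face tokenizer
# and the NLTK "punkt" model (nltk.download("punkt")).
def _demo_chunker():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    chunker = Chunker(tokenizer, chunk_size=64, max_chunk_size=96)
    text = (
        "First paragraph with a couple of sentences. Here is another one.\n\n"
        "Second paragraph with one more sentence."
    )
    for i, chunk in enumerate(chunker(text)):
        print(i, get_len(tokenizer, chunk), repr(chunk))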