File size: 3,890 Bytes
cb71ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import re
from functools import partial

import nltk


def get_len(tokenizer, text):
    """Return the number of tokens in *text* under *tokenizer* (no special tokens)."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return len(token_ids)

class Truncater:
    """Clip text to at most ``max_length`` tokens via an encode/decode round-trip."""

    def __init__(self, tokenizer, *, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, text):
        # Convenience: a Truncater instance is directly callable.
        return self.truncate(text)

    def truncate(self, text):
        """Return *text* shortened to its first ``max_length`` tokens.

        Relies on the tokenizer's own ``truncation`` support, then decodes
        the surviving ids back into a string.
        """
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_length,
        )
        return self.tokenizer.decode(token_ids)

class Refiner:
    """Drives an iterative ("refine") summarization loop over chunked text.

    Usage protocol: call the instance with the full text to split it into
    chunks and obtain a prompt generator; after the model answers each
    prompt, pass the answer to ``set_current_summary`` before advancing the
    generator, so the next prompt embeds the latest summary.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size

        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

        self.chunk_size = chunk_size
        self.max_chunk_size = max_chunk_size

        # Loop state: no summary yet, no chunks yet.
        self.current_summary = None
        self.chunks = []

        # Prompt templates; configure via set_prompts().
        self.initial_prompt = ""
        self.chunk_prefix = ""
        self.summary_prefix = ""
        self.refinement_prompt = ""

    def set_prompts(self, *, initial_prompt="", chunk_prefix="", summary_prefix="", refinement_prompt=""):
        """Install the prompt templates used while iterating."""
        self.initial_prompt = initial_prompt
        self.chunk_prefix = chunk_prefix
        self.summary_prefix = summary_prefix
        self.refinement_prompt = refinement_prompt

    @property
    def current_prompt(self):
        """The system prompt matching the loop phase: initial vs. refinement."""
        return self.initial_prompt if self.current_summary is None else self.refinement_prompt

    def __call__(self, text):
        self.chunks = Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)
        return self.refine(text)

    def __len__(self):
        # Number of chunks (and therefore prompts) produced for the last text.
        return len(self.chunks)

    def refine(self, text):
        """Yield one prompt per chunk, folding in the running summary.

        The *text* argument is unused here; chunking happens in ``__call__``.
        ``current_summary`` is re-read on every iteration so updates between
        yields take effect.
        """
        for piece in self.chunks:
            if self.current_summary is None:
                # First round: the bare chunk is the whole prompt.
                yield piece
                continue
            yield "{}{}\n\n{}{}".format(
                self.summary_prefix,
                self.current_summary,
                self.chunk_prefix,
                piece,
            )

    def set_current_summary(self, summary):
        """Record the model's latest summary for use in the next prompt."""
        self.current_summary = summary

class Chunker:
    """Splits text into token-bounded chunks along sentence boundaries.

    Aims for chunks of roughly ``chunk_size`` tokens, never exceeding
    ``max_chunk_size``. Sentence boundaries come from ``nltk.sent_tokenize``.
    """

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size

        self.chunk_size = chunk_size # target chunk size
        self.max_chunk_size = max_chunk_size # hard limit
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

    def __call__(self, text):
        return Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)

    @staticmethod
    def chunk_text(text, chunk_size, max_chunk_size, len_fn):
        """Split *text* into a list of chunk strings.

        Paragraph breaks are flattened to spaces before sentence
        tokenization, so chunk boundaries only ever fall between sentences.
        """
        paragraphs = re.split("\n\n|\n(?=[^\n])", text)
        text = " ".join(paragraphs)
        sentences = nltk.sent_tokenize(text)
        sentences = [s.strip() for s in sentences]
        chunks = []
        Chunker._chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn)
        return chunks

    @staticmethod
    def _chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn):
        """Greedily pack *sentences* into *chunks* (appended in place).

        If everything remaining fits under ``max_chunk_size`` (measured on
        the joined text), it becomes one final chunk. Otherwise sentences
        are taken greedily while the running sum of per-sentence lengths
        stays within ``chunk_size``.

        Raises:
            ValueError: if the next sentence alone exceeds ``chunk_size``,
                so no progress is possible.

        Note: iterative rather than recursive — the original recursed once
        per emitted chunk, which could exhaust the recursion limit on very
        long documents.
        """
        remaining = sentences
        while remaining:
            remaining_text = " ".join(remaining)
            if len_fn(remaining_text) <= max_chunk_size:
                # Tail fits under the hard limit: emit it whole and stop.
                chunks.append(remaining_text)
                return

            index = 0
            length_so_far = 0
            while index < len(remaining):
                # len_fn may be an expensive tokenizer call — compute once.
                sentence_len = len_fn(remaining[index])
                if length_so_far + sentence_len > chunk_size:
                    break
                length_so_far += sentence_len
                index += 1

            if index == 0:
                # First remaining sentence alone exceeds the target size.
                raise ValueError("No chunking possible")

            chunks.append(" ".join(remaining[:index]))
            remaining = remaining[index:]