import torch


class DocumentState:
    """Holds all intermediate tokenization state for a single document."""

    def __init__(self):
        # Boolean flags marking sentence/word boundaries at the subtoken level.
        self.sentence_end = []
        self.token_end = []
        # Original words and their subword pieces.
        self.orig_tokens = []
        self.tokens = []
        self.subtokens = []
        # Subtokens grouped into segments of at most max_segment_len pieces.
        self.segments = []
        # Maps each subtoken back to the index of the word it came from.
        self.subtoken_map = []
        self.segment_subtoken_map = []
        self.sentence_map = []
        self.tensorized_sent = []
        self.sent_len_list = []

    def finalize(self):
        subtoken_map = flatten(self.segment_subtoken_map)
        num_words = len(flatten(self.segments))
        assert num_words == len(subtoken_map), (num_words, len(subtoken_map))
        return {
            "orig_tokens": self.orig_tokens,
            "sentences": self.segments,
            "sent_len_list": self.sent_len_list,
            "tensorized_sent": self.tensorized_sent,
            "sentence_map": torch.tensor(
                get_sentence_map(self.segments, self.sentence_end)
            ),
            "subtoken_map": subtoken_map,
        }


def get_sentence_map(segments, sentence_end):
    """Assign a sentence index to every subtoken across all segments."""
    current = 0
    sent_map = []
    sent_end_idx = 0
    assert len(sentence_end) == sum([len(s) for s in segments])
    for segment in segments:
        for i in range(len(segment)):
            sent_map.append(current)
            # Advance the sentence counter once a sentence boundary is crossed.
            current += int(sentence_end[sent_end_idx])
            sent_end_idx += 1
    return sent_map
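

# Illustrative example (an assumption, not from the original code): for
# segments [[10, 11], [12]] with sentence_end [False, True, True],
# get_sentence_map returns [0, 0, 1]: the first two subtokens belong to
# sentence 0 and the last one to sentence 1.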


def split_into_segments(document_state, max_segment_len, constraints1, constraints2):
    """Greedily split the subtokens into segments of at most max_segment_len pieces.

    Segment boundaries are preferably placed where constraints1 holds
    (sentence ends); if no such position exists, fall back to constraints2
    (word ends). Two positions are reserved for the CLS/SEP special tokens.
    """
    current = 0
    while current < len(document_state.subtokens):
        # Reserve 2 positions for special tokens; -1 converts length to an index.
        end = min(current + max_segment_len - 1 - 2, len(document_state.subtokens) - 1)
        while end >= current and not constraints1[end]:
            end -= 1
        if end < current:
            end = min(
                current + max_segment_len - 1 - 2, len(document_state.subtokens) - 1
            )
            while end >= current and not constraints2[end]:
                end -= 1
            if end < current:
                raise Exception("Can't find valid segment")
        document_state.segments.append(document_state.subtokens[current : end + 1])
        subtoken_map = document_state.subtoken_map[current : end + 1]
        document_state.segment_subtoken_map.append(subtoken_map)
        if hasattr(document_state, "info"):
            # Optional per-subtoken metadata, if a caller attached it.
            info = document_state.info[current : end + 1]
            document_state.segment_info.append(info)
        current = end + 1
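

# Illustrative example (an assumption, not from the original code): with
# max_segment_len = 6 and six subtokens forming two three-piece sentences
# (sentence_end True at indices 2 and 5), the first pass places the boundary
# at the sentence end, yielding segments covering indices 0..2 and 3..5.
# Each segment keeps at most max_segment_len - 2 pieces to leave room for
# the CLS and SEP tokens added later.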


def flatten(l):
    return [item for sublist in l for item in sublist]


def get_tokenized_doc(doc, subword_tokenizer):
    """Subword-tokenize a doc given as a list of sentences, each a list of words."""
    document_state = DocumentState()
    word_idx = -1
    for sentence in doc:
        for word in sentence:
            document_state.orig_tokens.append(word)
            # Prepend a space so the BPE tokenizer produces the
            # sentence-internal form of the word.
            subtokens = subword_tokenizer.convert_tokens_to_ids(
                subword_tokenizer.tokenize(" " + word)
            )
            document_state.tokens.append(word)
            document_state.token_end += ([False] * (len(subtokens) - 1)) + [True]
            word_idx += 1
            for sidx, subtoken in enumerate(subtokens):
                document_state.subtokens.append(subtoken)
                document_state.sentence_end.append(False)
                document_state.subtoken_map.append(word_idx)
        # Mark the last subtoken of each sentence as a sentence boundary.
        document_state.sentence_end[-1] = True
    return document_state
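

# Illustrative example (hypothetical subword split): if "Ishmael" were broken
# into three subword pieces, token_end would record [False, False, True] for
# that word and subtoken_map would repeat its word index three times, so every
# piece can be mapped back to the original token.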


def basic_tokenize_doc(doc_str, basic_tokenizer):
    """Split a raw string into sentences of words using a basic tokenizer
    (e.g. a spaCy pipeline) that exposes sentence boundaries via `.sents`."""
    doc = []
    for sent in basic_tokenizer(doc_str).sents:
        wordlist = [str(word) for word in sent]
        doc.append(wordlist)
    return doc
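

# Hedged usage sketch (spaCy is an assumption and is not imported above):
#   import spacy
#   nlp = spacy.load("en_core_web_sm")  # any pipeline exposing doc.sents works
#   basic_tokenize_doc("First sentence. Second one.", nlp)
#   # -> [['First', 'sentence', '.'], ['Second', 'one', '.']]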


def tokenize_and_segment_doc(
    basic_tokenized_doc, subword_tokenizer, max_segment_len=4096
):
    document_state: DocumentState = get_tokenized_doc(
        basic_tokenized_doc, subword_tokenizer
    )
    document = post_tokenization_processing(
        document_state, subword_tokenizer, max_segment_len=max_segment_len
    )
    return document


def post_tokenization_processing(
    document_state: DocumentState, subword_tokenizer, max_segment_len=4096
):
    split_into_segments(
        document_state,
        max_segment_len,
        document_state.sentence_end,
        document_state.token_end,
    )
    sent_len_list = [len(sent) for sent in document_state.segments]
    document_state.sent_len_list = sent_len_list
    document_state.segments_indices = document_state.segments
    # Tensorize each segment by wrapping it in CLS/SEP special tokens.
    # Streaming coreference is done one window at a time, so no padding is required.
    tensorized_sent = [
        torch.unsqueeze(
            torch.tensor(
                [subword_tokenizer.cls_token_id]
                + sent
                + [subword_tokenizer.sep_token_id]
            ),
            dim=0,
        )
        for sent in document_state.segments
    ]
    document_state.tensorized_sent = tensorized_sent
    return document_state.finalize()
if __name__ == "__main__": | |
from transformers import LongformerTokenizerFast | |
tokenizer = LongformerTokenizerFast.from_pretrained( | |
"allenai/longformer-large-4096", | |
add_prefix_space=True, | |
clean_up_tokenization_spaces=True, | |
) | |
sample_doc_str = "My father’s eyes had closed upon the light of this world six months, when Ishmael opened on it." | |
print(get_tokenized_doc(sample_doc_str, tokenizer)) | |
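
    # Hedged end-to-end sketch: run the full pipeline on the pre-split doc.
    # In practice, basic_tokenize_doc with a sentence splitter such as a spaCy
    # pipeline would produce sample_doc from the raw string.
    tokenized_doc = tokenize_and_segment_doc(sample_doc, tokenizer, max_segment_len=4096)
    print(tokenized_doc["sent_len_list"])
    print(tokenized_doc["sentence_map"])
    print(tokenized_doc["subtoken_map"])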