File size: 3,655 Bytes
e086f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import pandas as pd
import numpy as np
import json

class PrepareArticles(object):
    """
    Parse preprocessed article data (a nested dict of sub-titles and sentences)
    into flat arrays of sentence texts, token lengths, contexts, and tags.

    This information is needed for evaluating log-perplexity of the text with respect to a language model
    and later on to test the likelihood that the sentence was sampled from the model with the relevant context.

    Parameters
    ----------
    article_obj : dict
        Expects the shape {'sub_titles': [{'sentences': [{'sentence': str,
        'context'?: str, 'alternative'?: str, 'alternative_context'?: str}, ...]}, ...]}.
    get_edits : bool
        When True, edited ('alternative') sentences are collected too, tagged '<edit>'.
    min_tokens, max_tokens : int
        Inclusive bounds on whitespace-token count for a sentence to be kept.
    max_sentences : int or None
        When set, collection stops once this many sentences are gathered, and
        articles that never reach the cap are dropped entirely (empty result).
    """

    def __init__(self, article_obj, get_edits=False, min_tokens=10, max_tokens=100, max_sentences=None):
        self.article_obj = article_obj
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.get_edits = get_edits
        self.max_sentences = max_sentences

    def __call__(self, combined=True):
        """Shorthand for :meth:`parse_dataset`."""
        return self.parse_dataset(combined)

    def parse_dataset(self, combined=True):
        """
        Collect sentences passing the token-length filter.

        Parameters
        ----------
        combined : bool
            When True, sentences are returned as one flat array; when False,
            each article's sentences are wrapped in their own numpy array.

        Returns
        -------
        dict
            Keys 'text', 'length', 'context', 'tag' (object arrays) and
            'number_in_par' (1-based positional indices).
        """
        texts = []
        lengths = []
        contexts = []
        tags = []

        current_texts = []
        current_lengths = []
        current_contexts = []
        current_tags = []
        exceeded_max_sentences = False

        for sub_title in self.article_obj['sub_titles']:  # For each sub title
            for sentence in sub_title['sentences']:  # Go over each sentence
                sentence_size = len(sentence['sentence'].split())
                if self.min_tokens <= sentence_size <= self.max_tokens:
                    current_texts.append(sentence['sentence'])
                    current_lengths.append(sentence_size)  # Number of tokens (reuse the count computed above)
                    current_contexts.append(sentence.get('context'))
                    current_tags.append('no edits')

                # If get_edits and an edited sentence exists and passes the length filter, save it
                if self.get_edits and 'alternative' in sentence:
                    alternative_size = len(sentence['alternative'].split())
                    if self.min_tokens <= alternative_size <= self.max_tokens:
                        current_texts.append(sentence['alternative'])
                        current_lengths.append(alternative_size)
                        current_contexts.append(sentence.get('alternative_context'))
                        current_tags.append('<edit>')

                if self.max_sentences and len(current_texts) >= self.max_sentences:
                    exceeded_max_sentences = True
                    break
            if exceeded_max_sentences:
                break

        # Keep the collected sentences only if no cap was set, or the cap was reached.
        # NOTE(review): articles that never reach max_sentences are dropped on purpose
        # (the original code had the same behavior) — confirm this filtering is intended.
        if exceeded_max_sentences or not self.max_sentences:
            if combined:
                # Flat arrays: one entry per sentence.
                texts += current_texts
                lengths += current_lengths
                contexts += current_contexts
                tags += current_tags
            else:
                # Nested arrays: one sub-array per article.
                texts.append(np.array(current_texts))
                lengths.append(np.array(current_lengths))
                contexts.append(np.array(current_contexts))
                tags.append(np.array(current_tags))

        return {'text': np.array(texts, dtype=object), 'length': np.array(lengths, dtype=object),
                'context': np.array(contexts, dtype=object), 'tag': np.array(tags, dtype=object),
                'number_in_par': np.arange(1, 1 + len(texts))}