import numpy as np
import pandas as pd
from multitest import MultiTest
from tqdm import tqdm
import logging
import json
import re
GAMMA = 0.45


def truncate_to_max_no_tokens(text, max_no_tokens):
    """Truncate text to at most max_no_tokens whitespace-separated tokens."""
    return " ".join(text.split()[:max_no_tokens])


class DetectLM(object):
    def __init__(self, sentence_detection_function, survival_function_per_length,
                 min_len=1, max_len=100, HC_type="stbl", gamma=GAMMA,
                 length_limit_policy='truncate', ignore_first_sentence=False, cache_logloss_path=''):
        """
        Test for the presence of sentences of irregular origin as reflected by the
        sentence_detection_function. The test is based on the sentence detection function
        and the P-values obtained from the survival function of the detector's responses.

        Args:
        ----
            :sentence_detection_function:  a function returning the response of the text 
            under the detector. Typically, the response is a logloss value under some language model.
            :survival_function_per_length:  survival_function_per_length(l, x) is the probability of the language
            model to produce a sentence value as extreme as x or more when the sentence s is the input to
            the detector. The function is defined for every sentence length l.
            The detector can also recieve a context c, in which case the input is the pair (s, c).
            :length_limit_policy: When a sentence exceeds ``max_len``, we can:
                'truncate':  truncate sentence to the maximal length :max_len
                'ignore':  do not evaluate the response and P-value for this sentence
                'max_available':  use the logloss function of the maximal available length
            :ignore_first_sentence:  whether to ignore the first sentence in the document or not. Useful when assuming
            that the first sentence is a title or a header or a context of the form previous sentence.
            :HC_type:  'stbl' True for the 2008 HC version, otherwise uses the 2004 version.
            :gamma:  the gamma parameter of the HC test.
            :cache_logloss_path: cache dict to restore the logloss faster
        """

        self.survival_function_per_length = survival_function_per_length
        self.sentence_detector = sentence_detection_function
        self.min_len = min_len
        self.max_len = max_len
        self.length_limit_policy = length_limit_policy
        self.ignore_first_sentence = ignore_first_sentence
        self.HC_stbl = (HC_type == 'stbl')
        self.gamma = gamma

        # Optional logloss cache: a flat JSON object mapping cleaned sentence
        # text to a precomputed logloss value.
        self.cache_logloss_path = cache_logloss_path
        self.cache_logloss = None
        if self.cache_logloss_path:
            try:
                with open(self.cache_logloss_path, 'r') as file:
                    self.cache_logloss = json.load(file)
            except (OSError, json.JSONDecodeError):
                logging.warning(f"Could not load cache file: {self.cache_logloss_path}")

    def _logperp(self, sent: str, context=None) -> float:
        return float(self.sentence_detector(sent, context))

    def _test_sentence(self, sentence: str, context=None):
        return self._logperp(sentence, context)
    
    def _get_length(self, sentence: str):
        return len(sentence.split())

    def _test_response(self, response: float, length: int):
        """
        Args:
            response:  sentence logloss
            length:    sentence length in tokens

        Returns:
          pvals:    P-value of the logloss of the sentence
          comments: comment on the P-value
        """
        if self.min_len <= length:
            comment = "OK"
            if length > self.max_len:  # in case length exceeds specifications...
                if self.length_limit_policy == 'truncate':
                    length = self.max_len
                    comment = f"truncated to {self.max_len} tokens"
                elif self.length_limit_policy == 'ignore':
                    comment = "ignored (above maximum limit)"
                    return dict(response=response,
                                pvalue=np.nan,
                                length=length,
                                comment=comment)
                elif self.length_limit_policy == 'max_available':
                    comment = "exceeding length limit; resorting to max-available length"
                    length = self.max_len
            pval = self.survival_function_per_length(length, response)
            assert pval >= 0, "Negative P-value. Something is wrong."

            return dict(response=response, 
                        pvalue=pval, 
                        length=length,
                        comment=comment)
        else:
            comment = "ignored (below minimal length)"
            return dict(response=response, 
                        pvalue=np.nan, 
                        length=length,
                        comment=comment)

    def _get_pvals(self, responses: list, lengths: list) -> tuple:
        """
        Pvalues from responses and lengths
        """
        pvals = []
        comments = []
        for response, length in zip(responses, lengths):
            if not np.isnan(response):
                r = self._test_response(response, length)
            else:
                r = dict(response=response, pvalue=np.nan, length=length, comment="ignored (no response)")
            pvals.append(float(r['pvalue']))
            comments.append(r['comment'])
        return pvals, comments

    def clean_string(self, s):
        """Normalize a sentence for cache lookup."""
        # Remove literal escape sequences (\n, \r, \t)
        s = re.sub(r'\\[nrt]', '', s)
        # Strip leading and trailing spaces and quotes
        s = s.strip().strip("'")
        # Convert to lower case
        return s.lower()
    
    def _get_logloss_cache(self, sent: str) -> float:
        """Return the cached logloss of a sentence, or None if absent."""
        if self.cache_logloss is None:
            return None
        return self.cache_logloss.get(sent.strip())
    
    def _get_responses(self, sentences: list, contexts: list) -> tuple:
        """
        Compute the response and length of every sentence in a list
        """
        assert len(sentences) == len(contexts)

        responses = []
        lengths = []
        for sent, ctx in tqdm(zip(sentences, contexts), total=len(sentences)):
            logging.debug(f"Testing sentence: {sent} | context: {ctx}")
            length = self._get_length(sent)
            if self.length_limit_policy == 'truncate':
                sent = truncate_to_max_no_tokens(sent, self.max_len)
            # Use the cached logloss when available; otherwise query the detector.
            sentence_response = self._get_logloss_cache(self.clean_string(sent))
            if sentence_response is None:
                sentence_response = self._test_sentence(sent, ctx)
            responses.append(sentence_response)
            lengths.append(length)
        return responses, lengths

    def get_pvals(self, sentences: list, contexts: list) -> tuple:
        """
        logloss test of every (sentence, context) pair
        """
        assert len(sentences) == len(contexts)

        responses, lengths = self._get_responses(sentences, contexts)
        pvals, comments = self._get_pvals(responses, lengths)
        return pvals, responses, comments

    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
        pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"
        
        df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
                           'context': lo_contexts, 'comment': comments},
                          index=range(len(lo_chunks)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            return None, df
        return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df

    def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
        if mt is None:
            hc = np.nan
            fisher = (np.nan, np.nan)
            minp = np.nan
            df['mask'] = pd.NA
        else:
            hc, hct = mt.hc(gamma=self.gamma)
            fisher = mt.fisher()
            minp = mt.minp()
            df['mask'] = df['pvalue'] <= hct
            if dashboard:
                mt.hc_dashboard(gamma=self.gamma)

        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1], minP=minp)
    
    def from_responses(self, responses: list, lengths: list, dashboard=False) -> dict:
        """
        Compute P-values from precomputed responses and lengths
        """

        pvals, comments = self._get_pvals(responses, lengths)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"

        df = pd.DataFrame({'response': responses, 'pvalue': pvals, 'comment': comments},
                          index=range(len(responses)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            hc = np.nan
            fisher = (np.nan, np.nan)
            df['mask'] = pd.NA
        else:
            mt = MultiTest(df_test.pvalue, stbl=self.HC_stbl)
            hc, hct = mt.hc(gamma=self.gamma)
            fisher = mt.fisher()
            df['mask'] = df['pvalue'] <= hct
            if dashboard:
                mt.hc_dashboard(gamma=self.gamma)
        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
    
    def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
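

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of wiring DetectLM to a toy detector, following only the
# contract described in the class docstring. The detector and survival
# function below are hypothetical stand-ins (and scipy is assumed to be
# installed): a real deployment would use a language model's per-sentence
# logloss and an empirical per-length null distribution.
if __name__ == "__main__":
    from scipy.stats import norm

    def toy_detector(sentence, context=None):
        # Hypothetical response: pseudo-random logloss seeded by the sentence.
        rng = np.random.default_rng(abs(hash(sentence)) % (2 ** 32))
        return float(rng.normal(loc=5.0, scale=1.0))

    def toy_survival(length, x):
        # Hypothetical null model: responses ~ N(5, 1) at every length.
        return float(norm.sf(x, loc=5.0, scale=1.0))

    detector = DetectLM(toy_detector, toy_survival, min_len=2, max_len=50)
    sentences = ["This is a short test sentence.",
                 "Another sentence, scored independently of the first.",
                 "A third sentence keeps the multiple-testing step honest."]
    results = detector(sentences, [None] * len(sentences))
    print(results['sentences'][['pvalue', 'comment']])
    print("HC =", results['HC'], "| Fisher =", results['fisher'])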