import numpy as np
import pandas as pd
from multitest import MultiTest
from tqdm import tqdm
import logging


def truncate_to_max_no_tokens(text, max_no_tokens):
    """Keep at most the first max_no_tokens whitespace-separated tokens of text."""
    return " ".join(text.split()[:max_no_tokens])


class DetectLM(object):
    def __init__(self, sentence_detection_function, survival_function_per_length,
                 min_len=4, max_len=100, HC_type="stbl",
                 length_limit_policy='truncate', ignore_first_sentence=False):
        """

        Test for the presence of sentences of irregular origin as reflected by the

        sentence_detection_function. The test is based on the sentence detection function

        and the P-values obtained from the survival function of the detector's responses.



        Args:

        ----

            :sentence_detection_function:  a function returning the response of the text 

            under the detector. Typically, the response is a logloss value under some language model.

            :survival_function_per_length:  survival_function_per_length(l, x) is the probability of the language

            model to produce a sentence value as extreme as x or more when the sentence s is the input to

            the detector. The function is defined for every sentence length l.

            The detector can also recieve a context c, in which case the input is the pair (s, c).

            :length_limit_policy: When a sentence exceeds ``max_len``, we can:

                'truncate':  truncate sentence to the maximal length :max_len

                'ignore':  do not evaluate the response and P-value for this sentence

                'max_available':  use the logloss function of the maximal available length

            :ignore_first_sentence:  whether to ignore the first sentence in the document or not. Useful when assuming

        context of the form previous sentence.

        """

        self.survival_function_per_length = survival_function_per_length
        self.sentence_detector = sentence_detection_function
        self.min_len = min_len
        self.max_len = max_len
        self.length_limit_policy = length_limit_policy
        self.ignore_first_sentence = ignore_first_sentence
        self.HC_stbl = (HC_type == 'stbl')

    def _logperp(self, sent: str, context=None) -> float:
        return float(self.sentence_detector(sent, context))

    def _test_sentence(self, sentence: str, context=None):
        return self._logperp(sentence, context)
    
    def _get_length(self, sentence: str):
        return len(sentence.split())

    def _test_response(self, response: float, length: int):
        """

        Args:

            response:  sentence logloss

            length:    sentence length in tokens



        Returns:

          pvals:    P-value of the logloss of the sentence

          comments: comment on the P-value

        """
        if np.isnan(response):
            return dict(response=response,
                        pvalue=np.nan,
                        length=length,
                        comment="ignored (no response)")
        if self.min_len <= length:
            comment = "OK"
            if length > self.max_len:  # in case length exceeds specifications...
                if self.length_limit_policy == 'truncate':
                    length = self.max_len
                    comment = f"truncated to {self.max_len} tokens"
                elif self.length_limit_policy == 'ignore':
                    return dict(response=response,
                                pvalue=np.nan,
                                length=length,
                                comment="ignored (above maximum length)")
                elif self.length_limit_policy == 'max_available':
                    comment = "exceeding length limit; resorting to max-available length"
                    length = self.max_len
            pval = self.survival_function_per_length(length, response)
            assert pval >= 0, "Negative P-value. Something is wrong."
            return dict(response=response,
                        pvalue=pval,
                        length=length,
                        comment=comment)
        else:
            return dict(response=response,
                        pvalue=np.nan,
                        length=length,
                        comment="ignored (below minimal length)")

    def _get_pvals(self, responses: list, lengths: list) -> tuple:
        pvals = []
        comments = []
        for response, length in zip(responses, lengths):
            r = self._test_response(response, length)
            pvals.append(float(r['pvalue']))
            comments.append(r['comment'])
        return pvals, comments


    def _get_responses(self, sentences: list, contexts: list) -> tuple:
        """
        Compute the detector response and token length of every sentence.
        """
        assert len(sentences) == len(contexts)

        responses = []
        lengths = []
        for sent, ctx in tqdm(zip(sentences, contexts)):
            logging.debug(f"Testing sentence: {sent} | context: {ctx}")
            length = self._get_length(sent)
            if self.length_limit_policy == 'truncate':
                sent = truncate_to_max_no_tokens(sent, self.max_len)
            if length == 1:
                logging.warning(f"Sentence {sent} is too short. Skipping.")
                responses.append(np.nan)
                lengths.append(length)
                continue
            try:
                responses.append(self._test_sentence(sent, ctx))
            except Exception as err:
                # Keep responses and lengths aligned even when the detector fails.
                logging.error(f"Error while testing sentence {sent}: {err}")
                responses.append(np.nan)
            lengths.append(length)
        return responses, lengths

    def get_pvals(self, sentences: list, contexts: list) -> tuple:
        """
        Logloss test of every (sentence, context) pair.
        """
        assert len(sentences) == len(contexts)

        responses, lengths = self._get_responses(sentences, contexts)
        pvals, comments = self._get_pvals(responses, lengths)
        return pvals, responses, comments


    def testHC(self, sentences: list) -> float:
        pvals = np.array(self.get_pvals(sentences, [None] * len(sentences))[0])
        mt = MultiTest(pvals[~np.isnan(pvals)], stbl=self.HC_stbl)
        return mt.hc(gamma=0.4)[0]

    def testFisher(self, sentences: list) -> dict:
        pvals = np.array(self.get_pvals(sentences, [None] * len(sentences))[0])
        logging.debug(f"P-values: {pvals}")
        mt = MultiTest(pvals[~np.isnan(pvals)], stbl=self.HC_stbl)
        return dict(zip(['Fn', 'pvalue'], mt.fisher()))

    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
        pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"
        
        df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
                           'context': lo_contexts, 'comment': comments},
                          index=range(len(lo_chunks)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            return None, df
        return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df

    def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
        if mt is None:
            hc = np.nan
            fisher = (np.nan, np.nan)
            df['mask'] = pd.NA
        else:
            hc, hct = mt.hc(gamma=0.4)
            fisher = mt.fisher()
            df['mask'] = df['pvalue'] <= hct
        if dashboard and mt is not None:
            mt.hc_dashboard(gamma=0.4)
        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])

    def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
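
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The toy_detector and
# toy_survival_function below are hypothetical stand-ins invented for this
# example: a real setup would wrap a language model's per-sentence logloss
# and a survival function fitted to the detector's null responses per
# sentence length. Assumes scipy is installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from scipy.stats import norm

    def toy_detector(sentence, context=None):
        # Hypothetical response: mean word length as a stand-in for logloss.
        words = sentence.split()
        return sum(len(w) for w in words) / len(words)

    def toy_survival_function(length, x):
        # Hypothetical null model: detector responses ~ N(5, 1) at every length.
        return norm.sf(x, loc=5.0, scale=1.0)

    detector = DetectLM(toy_detector, toy_survival_function,
                        min_len=2, max_len=50,
                        length_limit_policy='truncate')

    chunks = ["This is a perfectly ordinary short sentence.",
              "Sesquipedalian verbiage notwithstanding, antidisestablishmentarianism abounds.",
              "Another unremarkable sentence follows here."]
    contexts = [None] * len(chunks)

    results = detector(chunks, contexts)
    print(results['sentences'][['sentence', 'response', 'pvalue', 'comment']])
    print(f"HC = {results['HC']:.3f}, Fisher p-value = {results['fisher_pvalue']:.3f}")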