idkash1 commited on
Commit
d74e15b
·
verified ·
1 Parent(s): 92eab78

Update src/DetectLM.py

Browse files
Files changed (1) hide show
  1. src/DetectLM.py +177 -177
src/DetectLM.py CHANGED
@@ -1,178 +1,178 @@
1
- import numpy as np
2
- import pandas as pd
3
- from multitest import MultiTest
4
- from tqdm import tqdm
5
- import logging
6
-
7
-
8
- def truncae_to_max_no_tokens(text, max_no_tokens):
9
- return " ".join(text.split()[:max_no_tokens])
10
-
11
-
12
- class DetectLM(object):
13
- def __init__(self, sentence_detection_function, survival_function_per_length,
14
- min_len=4, max_len=100, HC_type="stbl",
15
- length_limit_policy='truncate', ignore_first_sentence=False):
16
- """
17
- Test for the presence of sentences of irregular origin as reflected by the
18
- sentence_detection_function. The test is based on the sentence detection function
19
- and the P-values obtained from the survival function of the detector's responses.
20
-
21
- Args:
22
- ----
23
- :sentence_detection_function: a function returning the response of the text
24
- under the detector. Typically, the response is a logloss value under some language model.
25
- :survival_function_per_length: survival_function_per_length(l, x) is the probability of the language
26
- model to produce a sentence value as extreme as x or more when the sentence s is the input to
27
- the detector. The function is defined for every sentence length l.
28
- The detector can also recieve a context c, in which case the input is the pair (s, c).
29
- :length_limit_policy: When a sentence exceeds ``max_len``, we can:
30
- 'truncate': truncate sentence to the maximal length :max_len
31
- 'ignore': do not evaluate the response and P-value for this sentence
32
- 'max_available': use the logloss function of the maximal available length
33
- :ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
34
- context of the form previous sentence.
35
- """
36
-
37
- self.survival_function_per_length = survival_function_per_length
38
- self.sentence_detector = sentence_detection_function
39
- self.min_len = min_len
40
- self.max_len = max_len
41
- self.length_limit_policy = length_limit_policy
42
- self.ignore_first_sentence = ignore_first_sentence
43
- self.HC_stbl = True if HC_type == 'stbl' else False
44
-
45
- def _logperp(self, sent: str, context=None) -> float:
46
- return float(self.sentence_detector(sent, context))
47
-
48
- def _test_sentence(self, sentence: str, context=None):
49
- return self._logperp(sentence, context)
50
-
51
- def _get_length(self, sentence: str):
52
- return len(sentence.split())
53
-
54
- def _test_response(self, response: float, length: int):
55
- """
56
- Args:
57
- response: sentence logloss
58
- length: sentence length in tokens
59
-
60
- Returns:
61
- pvals: P-value of the logloss of the sentence
62
- comments: comment on the P-value
63
- """
64
- if self.min_len <= length:
65
- comment = "OK"
66
- if length > self.max_len: # in case length exceeds specifications...
67
- if self.length_limit_policy == 'truncate':
68
- length = self.max_len
69
- comment = f"truncated to {self.max_len} tokens"
70
- elif self.length_limit_policy == 'ignore':
71
- comment = "ignored (above maximum limit)"
72
- return np.nan, np.nan, comment
73
- elif self.length_limit_policy == 'max_available':
74
- comment = "exceeding length limit; resorting to max-available length"
75
- length = self.max_len
76
- pval = self.survival_function_per_length(length, response)
77
- assert pval >= 0, "Negative P-value. Something is wrong."
78
- return dict(response=response,
79
- pvalue=pval,
80
- length=length,
81
- comment=comment)
82
- else:
83
- comment = "ignored (below minimal length)"
84
- return dict(response=response,
85
- pvalue=np.nan,
86
- length=length,
87
- comment=comment)
88
-
89
- def _get_pvals(self, responses: list, lengths: list) -> tuple:
90
- pvals = []
91
- comments = []
92
- for response, length in zip(responses, lengths):
93
- r = self._test_response(response, length)
94
- pvals.append(float(r['pvalue']))
95
- comments.append(r['comment'])
96
- return pvals, comments
97
-
98
-
99
- def _get_responses(self, sentences: list, contexts: list) -> list:
100
- """
101
- Compute response and length of a text sentence
102
- """
103
- assert len(sentences) == len(contexts)
104
-
105
- responses = []
106
- lengths = []
107
- for sent, ctx in tqdm(zip(sentences, contexts)):
108
- logging.debug(f"Testing sentence: {sent} | context: {ctx}")
109
- length = self._get_length(sent)
110
- if self.length_limit_policy == 'truncate':
111
- sent = truncae_to_max_no_tokens(sent, self.max_len)
112
- if length == 1:
113
- logging.warning(f"Sentence {sent} is too short. Skipping.")
114
- responses.append(np.nan)
115
- continue
116
- try:
117
- responses.append(self._test_sentence(sent, ctx))
118
- except:
119
- # something unusual happened...
120
- import pdb; pdb.set_trace()
121
- lengths.append(length)
122
- return responses, lengths
123
-
124
- def get_pvals(self, sentences: list, contexts: list) -> tuple:
125
- """
126
- logloss test of every (sentence, context) pair
127
- """
128
- assert len(sentences) == len(contexts)
129
-
130
- responses, lengths = self._get_responses(sentences, contexts)
131
- pvals, comments = self._get_pvals(responses, lengths)
132
-
133
- return pvals, responses, comments
134
-
135
-
136
- def testHC(self, sentences: list) -> float:
137
- pvals = np.array(self.get_pvals(sentences)[1])
138
- mt = MultiTest(pvals, stbl=self.HC_stbl)
139
- return mt.hc(gamma=0.4)[0]
140
-
141
- def testFisher(self, sentences: list) -> dict:
142
- pvals = np.array(self.get_pvals(sentences)[1])
143
- print(pvals)
144
- mt = MultiTest(pvals, stbl=self.HC_stbl)
145
- return dict(zip(['Fn', 'pvalue'], mt.fisher()))
146
-
147
- def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
148
- pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
149
- if self.ignore_first_sentence:
150
- pvals[0] = np.nan
151
- logging.info('Ignoring the first sentence.')
152
- comments[0] = "ignored (first sentence)"
153
-
154
- df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
155
- 'context': lo_contexts, 'comment': comments},
156
- index=range(len(lo_chunks)))
157
- df_test = df[~df.pvalue.isna()]
158
- if df_test.empty:
159
- logging.warning('No valid chunks to test.')
160
- return None, df
161
- return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df
162
-
163
- def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
164
- mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
165
- if mt is None:
166
- hc = np.nan
167
- fisher = (np.nan, np.nan)
168
- df['mask'] = pd.NA
169
- else:
170
- hc, hct = mt.hc(gamma=0.4)
171
- fisher = mt.fisher()
172
- df['mask'] = df['pvalue'] <= hct
173
- if dashboard:
174
- mt.hc_dashboard(gamma=0.4)
175
- return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
176
-
177
- def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
178
  return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from multitest import MultiTest
4
+ from tqdm import tqdm
5
+ import logging
6
+
7
+
8
+ def truncae_to_max_no_tokens(text, max_no_tokens):
9
+ return " ".join(text.split()[:max_no_tokens])
10
+
11
+
12
+ class DetectLM(object):
13
+ def __init__(self, sentence_detection_function, survival_function_per_length,
14
+ min_len=4, max_len=100, HC_type="stbl",
15
+ length_limit_policy='truncate', ignore_first_sentence=False):
16
+ """
17
+ Test for the presence of sentences of irregular origin as reflected by the
18
+ sentence_detection_function. The test is based on the sentence detection function
19
+ and the P-values obtained from the survival function of the detector's responses.
20
+
21
+ Args:
22
+ ----
23
+ :sentence_detection_function: a function returning the response of the text
24
+ under the detector. Typically, the response is a logloss value under some language model.
25
+ :survival_function_per_length: survival_function_per_length(l, x) is the probability of the language
26
+ model to produce a sentence value as extreme as x or more when the sentence s is the input to
27
+ the detector. The function is defined for every sentence length l.
28
+ The detector can also recieve a context c, in which case the input is the pair (s, c).
29
+ :length_limit_policy: When a sentence exceeds ``max_len``, we can:
30
+ 'truncate': truncate sentence to the maximal length :max_len
31
+ 'ignore': do not evaluate the response and P-value for this sentence
32
+ 'max_available': use the logloss function of the maximal available length
33
+ :ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
34
+ context of the form previous sentence.
35
+ """
36
+
37
+ self.survival_function_per_length = survival_function_per_length
38
+ self.sentence_detector = sentence_detection_function
39
+ self.min_len = min_len
40
+ self.max_len = max_len
41
+ self.length_limit_policy = length_limit_policy
42
+ self.ignore_first_sentence = ignore_first_sentence
43
+ self.HC_stbl = True if HC_type == 'stbl' else False
44
+
45
+ def _logperp(self, sent: str, context=None) -> float:
46
+ return float(self.sentence_detector(sent, context))
47
+
48
+ def _test_sentence(self, sentence: str, context=None):
49
+ return self._logperp(sentence, context)
50
+
51
+ def _get_length(self, sentence: str):
52
+ return len(sentence.split())
53
+
54
+ def _test_response(self, response: float, length: int):
55
+ """
56
+ Args:
57
+ response: sentence logloss
58
+ length: sentence length in tokens
59
+
60
+ Returns:
61
+ pvals: P-value of the logloss of the sentence
62
+ comments: comment on the P-value
63
+ """
64
+ if self.min_len <= length:
65
+ comment = "OK"
66
+ if length > self.max_len: # in case length exceeds specifications...
67
+ if self.length_limit_policy == 'truncate':
68
+ length = self.max_len
69
+ comment = f"truncated to {self.max_len} tokens"
70
+ elif self.length_limit_policy == 'ignore':
71
+ comment = "ignored (above maximum limit)"
72
+ return np.nan, np.nan, comment
73
+ elif self.length_limit_policy == 'max_available':
74
+ comment = "exceeding length limit; resorting to max-available length"
75
+ length = self.max_len
76
+ pval = self.survival_function_per_length(length, response)
77
+ assert pval >= 0, "Negative P-value. Something is wrong."
78
+ return dict(response=response,
79
+ pvalue=pval,
80
+ length=length,
81
+ comment=comment)
82
+ else:
83
+ comment = "ignored (below minimal length)"
84
+ return dict(response=response,
85
+ pvalue=np.nan,
86
+ length=length,
87
+ comment=comment)
88
+
89
+ def _get_pvals(self, responses: list, lengths: list) -> tuple:
90
+ pvals = []
91
+ comments = []
92
+ for response, length in zip(responses, lengths):
93
+ r = self._test_response(response, length)
94
+ pvals.append(float(r['pvalue']))
95
+ comments.append(r['comment'])
96
+ return pvals, comments
97
+
98
+
99
+ def _get_responses(self, sentences: list, contexts: list) -> list:
100
+ """
101
+ Compute response and length of a text sentence
102
+ """
103
+ assert len(sentences) == len(contexts)
104
+
105
+ responses = []
106
+ lengths = []
107
+ for sent, ctx in tqdm(zip(sentences, contexts)):
108
+ logging.debug(f"Testing sentence: {sent} | context: {ctx}")
109
+ length = self._get_length(sent)
110
+ if self.length_limit_policy == 'truncate':
111
+ sent = truncae_to_max_no_tokens(sent, self.max_len)
112
+ if length == 1:
113
+ logging.warning(f"Sentence {sent} is too short. Skipping.")
114
+ responses.append(np.nan)
115
+ continue
116
+ try:
117
+ responses.append(self._test_sentence(sent, ctx))
118
+ except:
119
+ # something unusual happened...
120
+ import pdb; pdb.set_trace()
121
+ lengths.append(length)
122
+ return responses, lengths
123
+
124
+ def get_pvals(self, sentences: list, contexts: list) -> tuple:
125
+ """
126
+ logloss test of every (sentence, context) pair
127
+ """
128
+ assert len(sentences) == len(contexts)
129
+
130
+ responses, lengths = self._get_responses(sentences, contexts)
131
+ pvals, comments = self._get_pvals(responses, lengths)
132
+
133
+ return pvals, responses, comments
134
+
135
+
136
+ def testHC(self, sentences: list) -> float:
137
+ pvals = np.array(self.get_pvals(sentences)[1])
138
+ mt = MultiTest(pvals, stbl=self.HC_stbl)
139
+ return mt.hc(gamma=0.4)[0]
140
+
141
+ def testFisher(self, sentences: list) -> dict:
142
+ pvals = np.array(self.get_pvals(sentences)[1])
143
+ print(pvals)
144
+ mt = MultiTest(pvals, stbl=self.HC_stbl)
145
+ return dict(zip(['Fn', 'pvalue'], mt.fisher()))
146
+
147
+ def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
148
+ pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
149
+ if self.ignore_first_sentence:
150
+ pvals[0] = np.nan
151
+ logging.info('Ignoring the first sentence.')
152
+ comments[0] = "ignored (first sentence)"
153
+
154
+ df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
155
+ 'context': lo_contexts, 'comment': comments},
156
+ index=range(len(lo_chunks)))
157
+ df_test = df[~df.pvalue.isna()]
158
+ if df_test.empty:
159
+ logging.warning('No valid chunks to test.')
160
+ return None, df
161
+ return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df
162
+
163
+ def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
164
+ mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
165
+ if mt is None:
166
+ hc = np.nan
167
+ fisher = (np.nan, np.nan)
168
+ df['mask'] = pd.NA
169
+ else:
170
+ hc, hct = mt.hc(gamma=0.4)
171
+ fisher = mt.fisher()
172
+ df['mask'] = df['pvalue'] <= hct
173
+ if dashboard:
174
+ mt.hc_dashboard(gamma=0.4)
175
+ return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
176
+
177
+ def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
178
  return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)