import numpy as np
import pandas as pd
from multitest import MultiTest
from tqdm import tqdm
import logging
import json
import re
GAMMA = 0.45
def truncate_to_max_no_tokens(text, max_no_tokens):
    """Truncate text to at most max_no_tokens whitespace-separated tokens."""
    return " ".join(text.split()[:max_no_tokens])
class DetectLM(object):
def __init__(self, sentence_detection_function, survival_function_per_length,
min_len=1, max_len=100, HC_type="stbl", gamma=GAMMA,
length_limit_policy='truncate', ignore_first_sentence=False, cache_logloss_path=''):
"""
Test for the presence of sentences of irregular origin as reflected by the
sentence_detection_function. The test is based on the sentence detection function
and the P-values obtained from the survival function of the detector's responses.
Args:
----
:sentence_detection_function: a function returning the response of the text
under the detector. Typically, the response is a logloss value under some language model.
:survival_function_per_length: survival_function_per_length(l, x) is the probability of the language
model to produce a sentence value as extreme as x or more when the sentence s is the input to
the detector. The function is defined for every sentence length l.
The detector can also recieve a context c, in which case the input is the pair (s, c).
:length_limit_policy: When a sentence exceeds ``max_len``, we can:
'truncate': truncate sentence to the maximal length :max_len
'ignore': do not evaluate the response and P-value for this sentence
'max_available': use the logloss function of the maximal available length
:ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
that the first sentence is a title or a header or a context of the form previous sentence.
:HC_type: 'stbl' True for the 2008 HC version, otherwise uses the 2004 version.
:gamma: the gamma parameter of the HC test.
:cache_logloss_path: cache dict to restore the logloss faster
"""
self.survival_function_per_length = survival_function_per_length
self.sentence_detector = sentence_detection_function
self.min_len = min_len
self.max_len = max_len
self.length_limit_policy = length_limit_policy
self.ignore_first_sentence = ignore_first_sentence
        self.HC_stbl = (HC_type == 'stbl')
self.gamma = gamma
        # Optional logloss cache (Idan, 26/05/2024)
        self.cache_logloss_path = cache_logloss_path
        try:
            # Load the cached {sentence: logloss} dictionary from the file
            with open(self.cache_logloss_path, 'r') as file:
                self.cache_logloss = json.load(file)
        except (OSError, json.JSONDecodeError):
            logging.warning(f"Could not load logloss cache from '{self.cache_logloss_path}'")
            self.cache_logloss = None
def _logperp(self, sent: str, context=None) -> float:
return float(self.sentence_detector(sent, context))
def _test_sentence(self, sentence: str, context=None):
return self._logperp(sentence, context)
def _get_length(self, sentence: str):
return len(sentence.split())
def _test_response(self, response: float, length: int):
"""
Args:
response: sentence logloss
length: sentence length in tokens
Returns:
pvals: P-value of the logloss of the sentence
comments: comment on the P-value
"""
if self.min_len <= length:
comment = "OK"
if length > self.max_len: # in case length exceeds specifications...
if self.length_limit_policy == 'truncate':
length = self.max_len
comment = f"truncated to {self.max_len} tokens"
elif self.length_limit_policy == 'ignore':
comment = "ignored (above maximum limit)"
                    return dict(response=response,
                                pvalue=np.nan,
                                length=length,
                                comment=comment)
elif self.length_limit_policy == 'max_available':
comment = "exceeding length limit; resorting to max-available length"
length = self.max_len
            pval = self.survival_function_per_length(length, response)
            assert pval >= 0, "Negative P-value. Something is wrong."
return dict(response=response,
pvalue=pval,
length=length,
comment=comment)
else:
comment = "ignored (below minimal length)"
return dict(response=response,
pvalue=np.nan,
length=length,
comment=comment)
def _get_pvals(self, responses: list, lengths: list) -> tuple:
"""
Pvalues from responses and lengths
"""
pvals = []
comments = []
for response, length in zip(responses, lengths):
if not np.isnan(response):
r = self._test_response(response, length)
else:
r = dict(response=response, pvalue=np.nan, length=length, comment="ignored (no response)")
pvals.append(float(r['pvalue']))
comments.append(r['comment'])
return pvals, comments
def clean_string(self, s):
        # Remove literal \n, \r, \t escape sequences
s = re.sub(r'\\[nrt]', '', s)
# Strip leading and trailing spaces and quotes
s = s.strip().strip("'")
# Convert to lower case
return s.lower()
def _get_logloss_cache(self, sent: str) -> float:
sent = sent.strip()
if self.cache_logloss is None: return None
if sent not in self.cache_logloss: return None
return self.cache_logloss[sent]
def _get_responses(self, sentences: list, contexts: list) -> list:
"""
Compute response and length of a every sentence in a list
"""
assert len(sentences) == len(contexts)
responses = []
lengths = []
for sent, ctx in tqdm(zip(sentences, contexts)):
logging.debug(f"Testing sentence: {sent} | context: {ctx}")
length = self._get_length(sent)
if self.length_limit_policy == 'truncate':
                sent = truncate_to_max_no_tokens(sent, self.max_len)
# if length == 1:
# logging.warning(f"Sentence {sent} is too short. Skipping.")
# responses.append(np.nan)
# continue
            try:
                # Try to get the logloss from the cache first
                sentence_response = self._get_logloss_cache(self.clean_string(sent))
                if sentence_response is not None:
                    responses.append(sentence_response)
                else:  # sentence not found in the cache; query the detector
                    current_response = self._test_sentence(sent, ctx)
                    responses.append(current_response)
            except Exception:
                # something unusual has happened; record a missing response and keep going
                logging.exception(f"Failed to evaluate sentence: {sent}")
                responses.append(np.nan)
lengths.append(length)
return responses, lengths
def get_pvals(self, sentences: list, contexts: list) -> tuple:
"""
logloss test of every (sentence, context) pair
"""
assert len(sentences) == len(contexts)
responses, lengths = self._get_responses(sentences, contexts)
pvals, comments = self._get_pvals(responses, lengths)
return pvals, responses, comments
    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
        """Return a (MultiTest over the valid P-values or None, per-sentence DataFrame) pair."""
pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
if self.ignore_first_sentence:
pvals[0] = np.nan
logging.info('Ignoring the first sentence.')
comments[0] = "ignored (first sentence)"
df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
'context': lo_contexts, 'comment': comments},
index=range(len(lo_chunks)))
df_test = df[~df.pvalue.isna()]
if df_test.empty:
logging.warning('No valid chunks to test.')
return None, df
return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df
    def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
        if mt is None:
            hc = np.nan
            fisher = (np.nan, np.nan)
            minp = np.nan
            df['mask'] = pd.NA
        else:
            hc, hct = mt.hc(gamma=self.gamma)
            fisher = mt.fisher()
            minp = mt.minp()
            df['mask'] = df['pvalue'] <= hct
            if dashboard:
                mt.hc_dashboard(gamma=self.gamma)
        dc = dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1], minP=minp)
        return dc
    def from_responses(self, responses: list, lengths: list, dashboard=False) -> dict:
        """
        Compute P-values and multiple-testing statistics from pre-computed responses and lengths.
        """
        pvals, comments = self._get_pvals(responses, lengths)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"
        df = pd.DataFrame({'response': responses, 'pvalue': pvals, 'comment': comments},
                          index=range(len(responses)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            hc = np.nan
            fisher = (np.nan, np.nan)
            df['mask'] = pd.NA
        else:
            mt = MultiTest(df_test.pvalue, stbl=self.HC_stbl)
            hc, hct = mt.hc(gamma=self.gamma)
            fisher = mt.fisher()
            df['mask'] = df['pvalue'] <= hct
            if dashboard:
                mt.hc_dashboard(gamma=self.gamma)
        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
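

# The block below is a minimal usage sketch, not part of the original module. It assumes a
# hypothetical `dummy_detector` (a toy logloss with added noise) and uses the illustrative
# empirical survival function defined near the top of this file. In practice, the detector
# would be a language-model logloss and the survival function would be calibrated on held-out
# human-written text.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    def dummy_detector(sentence, context=None):
        # Hypothetical stand-in for a language-model logloss: grows with sentence length
        return len(sentence.split()) * 0.5 + rng.normal(scale=0.1)

    # Null responses per sentence length, sampled from the dummy detector itself
    null_samples = {l: np.array([dummy_detector(" ".join(["word"] * l)) for _ in range(200)])
                    for l in range(1, 101)}
    survival_fn = make_empirical_survival_function(null_samples)

    detector = DetectLM(dummy_detector, survival_fn, min_len=1, max_len=100)
    chunks = ["this is a short test sentence",
              "another sentence of unknown origin",
              "a third chunk to make the multiple testing statistics meaningful",
              "and one more sentence to round out the toy document"]
    contexts = [None] * len(chunks)
    results = detector(chunks, contexts)
    print(results['sentences'])
    print("HC:", results['HC'], "Fisher:", results['fisher'], results['fisher_pvalue'])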