import numpy as np
import pandas as pd
from multitest import MultiTest
from tqdm import tqdm
import logging


def truncate_to_max_no_tokens(text, max_no_tokens):
    """Keep at most the first max_no_tokens whitespace-separated tokens of text."""
    return " ".join(text.split()[:max_no_tokens])


class DetectLM:
def __init__(self, sentence_detection_function, survival_function_per_length,
min_len=4, max_len=100, HC_type="stbl",
length_limit_policy='truncate', ignore_first_sentence=False):
"""
Test for the presence of sentences of irregular origin as reflected by the
sentence_detection_function. The test is based on the sentence detection function
and the P-values obtained from the survival function of the detector's responses.
Args:
----
:sentence_detection_function: a function returning the response of the text
under the detector. Typically, the response is a logloss value under some language model.
:survival_function_per_length: survival_function_per_length(l, x) is the probability of the language
model to produce a sentence value as extreme as x or more when the sentence s is the input to
the detector. The function is defined for every sentence length l.
The detector can also recieve a context c, in which case the input is the pair (s, c).
:length_limit_policy: When a sentence exceeds ``max_len``, we can:
'truncate': truncate sentence to the maximal length :max_len
'ignore': do not evaluate the response and P-value for this sentence
'max_available': use the logloss function of the maximal available length
:ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
context of the form previous sentence.
"""
self.survival_function_per_length = survival_function_per_length
self.sentence_detector = sentence_detection_function
self.min_len = min_len
self.max_len = max_len
self.length_limit_policy = length_limit_policy
self.ignore_first_sentence = ignore_first_sentence
        self.HC_stbl = (HC_type == 'stbl')

    def _logperp(self, sent: str, context=None) -> float:
return float(self.sentence_detector(sent, context))

    def _test_sentence(self, sentence: str, context=None):
return self._logperp(sentence, context)

    def _get_length(self, sentence: str):
return len(sentence.split())

    def _test_response(self, response: float, length: int):
        """
        Compute the P-value of a single detector response.

        Args:
            response: sentence logloss
            length: sentence length in tokens

        Returns:
            dict with the response, its P-value, the effective length, and a
            comment describing how the length-limit policy was applied
        """
        if np.isnan(response):
            # No response was recorded for this sentence (e.g., the detector
            # failed or the sentence was skipped); propagate a missing P-value.
            return dict(response=response, pvalue=np.nan, length=length,
                        comment="ignored (no response)")
if self.min_len <= length:
comment = "OK"
if length > self.max_len: # in case length exceeds specifications...
if self.length_limit_policy == 'truncate':
length = self.max_len
comment = f"truncated to {self.max_len} tokens"
            elif self.length_limit_policy == 'ignore':
                comment = "ignored (above maximum limit)"
                return dict(response=response, pvalue=np.nan,
                            length=length, comment=comment)
elif self.length_limit_policy == 'max_available':
comment = "exceeding length limit; resorting to max-available length"
length = self.max_len
pval = self.survival_function_per_length(length, response)
assert pval >= 0, "Negative P-value. Something is wrong."
return dict(response=response,
pvalue=pval,
length=length,
comment=comment)
else:
comment = "ignored (below minimal length)"
return dict(response=response,
pvalue=np.nan,
length=length,
comment=comment)

    def _get_pvals(self, responses: list, lengths: list) -> tuple:
        """
        Convert detector responses to P-values via the survival function.
        """
pvals = []
comments = []
for response, length in zip(responses, lengths):
r = self._test_response(response, length)
pvals.append(float(r['pvalue']))
comments.append(r['comment'])
return pvals, comments

    def _get_responses(self, sentences: list, contexts: list) -> tuple:
        """
        Compute the detector response and token length of every sentence.
        """
        assert len(sentences) == len(contexts)

        responses = []
        lengths = []
        for sent, ctx in tqdm(zip(sentences, contexts), total=len(sentences)):
            logging.debug(f"Testing sentence: {sent} | context: {ctx}")
            length = self._get_length(sent)
            if self.length_limit_policy == 'truncate':
                sent = truncate_to_max_no_tokens(sent, self.max_len)
            if length == 1:
                logging.warning(f"Sentence {sent} is too short. Skipping.")
                responses.append(np.nan)
                lengths.append(length)
                continue
            try:
                responses.append(self._test_sentence(sent, ctx))
            except Exception as exc:
                # Something unusual happened; record a missing response and move
                # on instead of dropping into a debugger.
                logging.error(f"Error while testing sentence {sent}: {exc}")
                responses.append(np.nan)
            lengths.append(length)
        return responses, lengths

    def get_pvals(self, sentences: list, contexts: list) -> tuple:
        """
        Logloss test of every (sentence, context) pair.

        Returns:
            P-values, responses, and comments (one entry per sentence).
        """
assert len(sentences) == len(contexts)
responses, lengths = self._get_responses(sentences, contexts)
pvals, comments = self._get_pvals(responses, lengths)
return pvals, responses, comments

    def testHC(self, sentences: list) -> float:
        pvals = np.array(self.get_pvals(sentences, [None] * len(sentences))[0])
        pvals = pvals[~np.isnan(pvals)]  # drop sentences without a valid P-value
        mt = MultiTest(pvals, stbl=self.HC_stbl)
        return mt.hc(gamma=0.4)[0]

    def testFisher(self, sentences: list) -> dict:
        pvals = np.array(self.get_pvals(sentences, [None] * len(sentences))[0])
        pvals = pvals[~np.isnan(pvals)]  # drop sentences without a valid P-value
        logging.debug(pvals)
        mt = MultiTest(pvals, stbl=self.HC_stbl)
        return dict(zip(['Fn', 'pvalue'], mt.fisher()))

    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
        """
        Evaluate the P-value of every chunk and wrap the valid ones in a MultiTest.
        """
pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
if self.ignore_first_sentence:
pvals[0] = np.nan
logging.info('Ignoring the first sentence.')
comments[0] = "ignored (first sentence)"
df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
'context': lo_contexts, 'comment': comments},
index=range(len(lo_chunks)))
df_test = df[~df.pvalue.isna()]
if df_test.empty:
logging.warning('No valid chunks to test.')
return None, df
return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df

    def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        """
        Test a document given as a list of chunks with matching contexts.

        Returns:
            dict with the per-sentence dataframe, the HC statistic, and Fisher's
            statistic with its P-value.
        """
mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
if mt is None:
hc = np.nan
fisher = (np.nan, np.nan)
df['mask'] = pd.NA
else:
hc, hct = mt.hc(gamma=0.4)
fisher = mt.fisher()
df['mask'] = df['pvalue'] <= hct
if dashboard:
mt.hc_dashboard(gamma=0.4)
return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])

    def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
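

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the detection logic).
    # The detector and survival function below are hypothetical stand-ins: in
    # practice, the response would be a per-sentence logloss under a language
    # model, and the survival function would be calibrated per sentence length
    # on held-out null responses.
    import math

    rng = np.random.default_rng(42)

    def dummy_sentence_detector(sentence, context=None):
        # Hypothetical detector: a random "logloss"-like response ~ N(5, 1).
        return float(rng.normal(loc=5.0, scale=1.0))

    def dummy_survival_function(length, response):
        # Assumed null model matching the dummy detector: for every length,
        # responses are N(5, 1), so the survival function is the normal upper
        # tail Pr[X >= response].
        return 0.5 * math.erfc((response - 5.0) / math.sqrt(2.0))

    detector = DetectLM(dummy_sentence_detector, dummy_survival_function,
                        min_len=2, max_len=50)
    chunks = ["this is a short test sentence",
              "another sentence to score",
              "and one more sentence for the demo"]
    results = detector(chunks, [None] * len(chunks))
    print(results['sentences'])
    print("HC =", results['HC'], "| Fisher P-value =", results['fisher_pvalue'])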