import numpy as np
import pandas as pd
from multitest import MultiTest
from tqdm import tqdm
import logging
import json
import re

GAMMA = 0.45


def truncate_to_max_no_tokens(text, max_no_tokens):
    """Keep only the first max_no_tokens whitespace-separated tokens of text."""
    return " ".join(text.split()[:max_no_tokens])

class DetectLM(object):
    def __init__(self, sentence_detection_function, survival_function_per_length,
                 min_len=1, max_len=100, HC_type="stbl", gamma=GAMMA,
                 length_limit_policy='truncate', ignore_first_sentence=False, cache_logloss_path=''):
        """
        Test for the presence of sentences of irregular origin as reflected by the
        sentence_detection_function. The test is based on the sentence detection function
        and the P-values obtained from the survival function of the detector's responses.

        Args:
        ----
        :sentence_detection_function:  a function returning the response of the text
            under the detector. Typically, the response is a logloss value under some language model.
        :survival_function_per_length:  survival_function_per_length(l, x) is the probability
            of the language model producing a response as extreme as x or more when a sentence s
            of length l is the input to the detector. The function is defined for every sentence
            length l. The detector can also receive a context c, in which case the input is
            the pair (s, c).
        :length_limit_policy:  when a sentence exceeds ``max_len``, we can:
            'truncate':  truncate the sentence to ``max_len`` tokens
            'ignore':  do not evaluate the response and P-value for this sentence
            'max_available':  use the survival function of the maximal available length
        :ignore_first_sentence:  whether to ignore the first sentence in the document. Useful
            when assuming that the first sentence is a title, a header, or a context in the
            form of the previous sentence.
        :HC_type:  'stbl' for the 2008 HC version; otherwise the 2004 version is used.
        :gamma:  the gamma parameter of the HC test.
        :cache_logloss_path:  path to a JSON cache of per-sentence loglosses, used to restore
            responses faster.
        """

        self.survival_function_per_length = survival_function_per_length
        self.sentence_detector = sentence_detection_function
        self.min_len = min_len
        self.max_len = max_len
        self.length_limit_policy = length_limit_policy
        self.ignore_first_sentence = ignore_first_sentence
        self.HC_stbl = (HC_type == 'stbl')
        self.gamma = gamma

        self.cache_logloss_path = cache_logloss_path
        try:
            with open(self.cache_logloss_path, 'r') as file:
                self.cache_logloss = json.load(file)
        except (OSError, json.JSONDecodeError):
            logging.warning('Could not load logloss cache file; responses will be computed from scratch.')
            self.cache_logloss = None
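
    # Note (illustrative, not prescribed by this class): one way to obtain a compatible
    # survival_function_per_length is to evaluate the detector on a corpus of
    # human-written sentences, group the responses by sentence length l, and return
    # the empirical tail probability 1 - ECDF_l(x). The __main__ sketch at the bottom
    # of this file uses a simpler parametric stand-in.
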
    def _logperp(self, sent: str, context=None) -> float:
        return float(self.sentence_detector(sent, context))

    def _test_sentence(self, sentence: str, context=None):
        return self._logperp(sentence, context)

    def _get_length(self, sentence: str):
        # sentence length in whitespace-separated tokens
        return len(sentence.split())

    def _test_response(self, response: float, length: int):
        """
        Args:
            response: sentence logloss
            length: sentence length in tokens

        Returns:
            dict with the response, its P-value, the effective length, and a comment
        """
        if self.min_len <= length:
            comment = "OK"
            if length > self.max_len:
                if self.length_limit_policy == 'truncate':
                    length = self.max_len
                    comment = f"truncated to {self.max_len} tokens"
                elif self.length_limit_policy == 'ignore':
                    comment = "ignored (above maximum limit)"
                    return dict(response=response, pvalue=np.nan, length=length, comment=comment)
                elif self.length_limit_policy == 'max_available':
                    comment = "exceeding length limit; resorting to max-available length"
                    length = self.max_len
            pval = self.survival_function_per_length(length, response)
            assert pval >= 0, "Negative P-value. Something is wrong."
            return dict(response=response,
                        pvalue=pval,
                        length=length,
                        comment=comment)
        else:
            comment = "ignored (below minimal length)"
            return dict(response=response,
                        pvalue=np.nan,
                        length=length,
                        comment=comment)

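    # Example (hypothetical numbers): with max_len=100 and length_limit_policy
    # 'truncate', _test_response(3.2, 120) evaluates the survival function at
    # length 100 and records the truncation in the returned comment.
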
    def _get_pvals(self, responses: list, lengths: list) -> tuple:
        """
        P-values from responses and lengths
        """
        pvals = []
        comments = []
        for response, length in zip(responses, lengths):
            if not np.isnan(response):
                r = self._test_response(response, length)
            else:
                r = dict(response=response, pvalue=np.nan, length=length, comment="ignored (no response)")
            pvals.append(float(r['pvalue']))
            comments.append(r['comment'])
        return pvals, comments

    def clean_string(self, s):
        # drop literal escape sequences (\n, \r, \t), surrounding quotes, and case
        s = re.sub(r'\\[nrt]', '', s)
        s = s.strip().strip("'")
        return s.lower()

    def _get_logloss_cache(self, sent: str) -> float:
        sent = sent.strip()
        if self.cache_logloss is None:
            return None
        if sent not in self.cache_logloss:
            return None
        return self.cache_logloss[sent]

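    # When present, the cache is expected to be a flat JSON object mapping a cleaned
    # sentence string to its previously computed logloss (see cache_logloss_path).
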
    def _get_responses(self, sentences: list, contexts: list) -> tuple:
        """
        Compute the response and length of every sentence in a list
        """
        assert len(sentences) == len(contexts)

        responses = []
        lengths = []
        for sent, ctx in tqdm(zip(sentences, contexts), total=len(sentences)):
            logging.debug(f"Testing sentence: {sent} | context: {ctx}")
            length = self._get_length(sent)
            if self.length_limit_policy == 'truncate':
                sent = truncate_to_max_no_tokens(sent, self.max_len)
            # use a cached logloss when available; otherwise query the detector
            sentence_response = self._get_logloss_cache(self.clean_string(sent))
            if sentence_response is not None:
                responses.append(sentence_response)
            else:
                responses.append(self._test_sentence(sent, ctx))
            lengths.append(length)
        return responses, lengths

    def get_pvals(self, sentences: list, contexts: list) -> tuple:
        """
        Logloss test of every (sentence, context) pair
        """
        assert len(sentences) == len(contexts)

        responses, lengths = self._get_responses(sentences, contexts)
        pvals, comments = self._get_pvals(responses, lengths)
        return pvals, responses, comments

    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
        pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"

        df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
                           'context': lo_contexts, 'comment': comments},
                          index=range(len(lo_chunks)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            return None, df
        return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df

    def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
        if mt is None:
            hc = np.nan
            fisher = (np.nan, np.nan)
            minp = np.nan
            df['mask'] = pd.NA
        else:
            hc, hct = mt.hc(gamma=self.gamma)
            fisher = mt.fisher()
            minp = mt.minp()
            df['mask'] = df['pvalue'] <= hct
            if dashboard:
                mt.hc_dashboard(gamma=self.gamma)

        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1], minP=minp)

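    # Consuming the result (hypothetical variable names):
    #   res = detector.test_chunked_doc(chunks, contexts)
    #   flagged = res['sentences'][res['sentences']['mask'] == True]
    # 'mask' marks sentences whose P-value falls at or below the HC threshold.
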
    def from_responses(self, responses: list, lengths: list, dashboard=False) -> dict:
        """
        Compute P-values from responses and lengths
        """

        pvals, comments = self._get_pvals(responses, lengths)
        if self.ignore_first_sentence:
            pvals[0] = np.nan
            logging.info('Ignoring the first sentence.')
            comments[0] = "ignored (first sentence)"

        df = pd.DataFrame({'response': responses, 'pvalue': pvals, 'comment': comments},
                          index=range(len(responses)))
        df_test = df[~df.pvalue.isna()]
        if df_test.empty:
            logging.warning('No valid chunks to test.')
            df['mask'] = pd.NA
            return dict(sentences=df, HC=np.nan, fisher=np.nan, fisher_pvalue=np.nan)

        mt = MultiTest(df_test.pvalue, stbl=self.HC_stbl)
        hc, hct = mt.hc(gamma=self.gamma)
        fisher = mt.fisher()
        df['mask'] = df['pvalue'] <= hct
        if dashboard:
            mt.hc_dashboard(gamma=self.gamma)
        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])

    def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
        return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
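

# A minimal usage sketch (hypothetical, not part of the library API): the detector
# and survival function below are stand-ins. A real deployment would compute
# per-sentence logloss under a language model and calibrate
# survival_function_per_length per sentence length on human-written text.
if __name__ == "__main__":
    import scipy.stats as stats

    def toy_detector(sentence, context=None):
        # stand-in for a language-model logloss
        return 3.0 + 0.1 * (len(sentence.split()) % 5)

    def toy_survival_function(length, response):
        # P(X >= response) under an assumed normal response distribution; a
        # calibrated function would depend on the length argument as well
        return float(stats.norm.sf(response, loc=3.0, scale=0.5))

    detector = DetectLM(toy_detector, toy_survival_function,
                        min_len=2, max_len=50, length_limit_policy='truncate')
    chunks = ["This is a short example sentence.",
              "Here is another sentence to test."]
    results = detector(chunks, [None, None])
    print(results['sentences'])
    print("HC:", results['HC'], "Fisher:", results['fisher'])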