idkash1 committed
Commit 689fec3 · verified · 1 Parent(s): c8b4788

Update src/DetectLM.py

Files changed (1)
  1. src/DetectLM.py +94 -28
src/DetectLM.py CHANGED
@@ -3,6 +3,9 @@ import pandas as pd
 from multitest import MultiTest
 from tqdm import tqdm
 import logging
+import json
+import re
+GAMMA = 0.45
 
 
 def truncae_to_max_no_tokens(text, max_no_tokens):
@@ -11,8 +14,8 @@ def truncae_to_max_no_tokens(text, max_no_tokens):
 
 class DetectLM(object):
     def __init__(self, sentence_detection_function, survival_function_per_length,
-                 min_len=4, max_len=100, HC_type="stbl", gamma=0.15,
-                 length_limit_policy='truncate', ignore_first_sentence=False):
+                 min_len=1, max_len=100, HC_type="stbl", gamma=GAMMA,
+                 length_limit_policy='truncate', ignore_first_sentence=False, cache_logloss_path=''):
         """
         Test for the presence of sentences of irregular origin as reflected by the
         sentence_detection_function. The test is based on the sentence detection function
@@ -31,7 +34,10 @@ class DetectLM(object):
         'ignore': do not evaluate the response and P-value for this sentence
         'max_available': use the logloss function of the maximal available length
         :ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
-        context of the form previous sentence.
+        that the first sentence is a title or a header or a context of the form previous sentence.
+        :HC_type: 'stbl' True for the 2008 HC version, otherwise uses the 2004 version.
+        :gamma: the gamma parameter of the HC test.
+        :cache_logloss_path: cache dict to restore the logloss faster
         """
 
         self.survival_function_per_length = survival_function_per_length
@@ -43,6 +49,16 @@ class DetectLM(object):
         self.HC_stbl = True if HC_type == 'stbl' else False
         self.gamma = gamma
 
+        # Idan 26/05/204
+        self.cache_logloss_path = cache_logloss_path
+        try:
+            # Load the dictionary from the file
+            with open(self.cache_logloss_path, 'r') as file:
+                self.cache_logloss = json.load(file)
+        except:
+            print('Could not find cache file')
+            self.cache_logloss = None
+
     def _logperp(self, sent: str, context=None) -> float:
         return float(self.sentence_detector(sent, context))
 
@@ -75,7 +91,11 @@ class DetectLM(object):
             comment = "exceeding length limit; resorting to max-available length"
             length = self.max_len
         pval = self.survival_function_per_length(length, response)
-        assert pval >= 0, "Negative P-value. Something is wrong."
+        try:
+            assert pval >= 0, "Negative P-value. Something is wrong."
+        except:
+            import pdb; pdb.set_trace()
+
         return dict(response=response,
                     pvalue=pval,
                     length=length,
@@ -88,18 +108,37 @@ class DetectLM(object):
                     comment=comment)
 
     def _get_pvals(self, responses: list, lengths: list) -> tuple:
+        """
+        Pvalues from responses and lengths
+        """
         pvals = []
         comments = []
         for response, length in zip(responses, lengths):
-            r = self._test_response(response, length)
+            if not np.isnan(response):
+                r = self._test_response(response, length)
+            else:
+                r = dict(response=response, pvalue=np.nan, length=length, comment="ignored (no response)")
             pvals.append(float(r['pvalue']))
             comments.append(r['comment'])
         return pvals, comments
 
-
+    def clean_string(self, s):
+        # Remove escape characters
+        s = re.sub(r'\\[nrt]', '', s)
+        # Strip leading and trailing spaces and quotes
+        s = s.strip().strip("'")
+        # Convert to lower case
+        return s.lower()
+
+    def _get_logloss_cache(self, sent: str) -> float:
+        sent = sent.strip()
+        if self.cache_logloss is None: return None
+        if sent not in self.cache_logloss: return None
+        return self.cache_logloss[sent]
+
     def _get_responses(self, sentences: list, contexts: list) -> list:
         """
-        Compute response and length of a text sentence
+        Compute response and length of a every sentence in a list
         """
         assert len(sentences) == len(contexts)
 
@@ -110,14 +149,20 @@ class DetectLM(object):
             length = self._get_length(sent)
             if self.length_limit_policy == 'truncate':
                 sent = truncae_to_max_no_tokens(sent, self.max_len)
-            if length == 1:
-                logging.warning(f"Sentence {sent} is too short. Skipping.")
-                responses.append(np.nan)
-                continue
+            # if length == 1:
+            #     logging.warning(f"Sentence {sent} is too short. Skipping.")
+            #     responses.append(np.nan)
+            #     continue
             try:
-                responses.append(self._test_sentence(sent, ctx))
+                # Try getting logloss from cache
+                sentence_response = self._get_logloss_cache(self.clean_string(sent))
+                if sentence_response != None:
+                    responses.append(sentence_response)
+                else:  # If sentence not found
+                    current_response = self._test_sentence(sent, ctx)
+                    responses.append(current_response)
             except:
-                # something unusual happened...
+                # something unusual has happened...
                 import pdb; pdb.set_trace()
             lengths.append(length)
         return responses, lengths
@@ -130,22 +175,9 @@ class DetectLM(object):
 
         responses, lengths = self._get_responses(sentences, contexts)
         pvals, comments = self._get_pvals(responses, lengths)
-
         return pvals, responses, comments
 
-
-    def testHC(self, sentences: list) -> float:
-        pvals = np.array(self.get_pvals(sentences)[1])
-        mt = MultiTest(pvals, stbl=self.HC_stbl)
-        return mt.hc(gamma=self.gamma)[0]
-
-    def testFisher(self, sentences: list) -> dict:
-        pvals = np.array(self.get_pvals(sentences)[1])
-        print(pvals)
-        mt = MultiTest(pvals, stbl=self.HC_stbl)
-        return dict(zip(['Fn', 'pvalue'], mt.fisher()))
-
-    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
+    def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> (MultiTest, pd.DataFrame):
         pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
         if self.ignore_first_sentence:
             pvals[0] = np.nan
@@ -173,7 +205,41 @@ class DetectLM(object):
         df['mask'] = df['pvalue'] <= hct
         if dashboard:
             mt.hc_dashboard(gamma=self.gamma)
-        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
+
+        dc = dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1], minP=mt.minp(), bonf=mt.bonfferoni())
+        return dc
+
+    def from_responses(self, responses: list, lengths: list, dashboard=False) -> dict:
+        """
+        Compute P-values from responses and lengths
+        """
+
+        pvals, comments = self._get_pvals(responses, lengths)
+        if self.ignore_first_sentence:
+            pvals[0] = np.nan
+            logging.info('Ignoring the first sentence.')
+            comments[0] = "ignored (first sentence)"
+
+        df = pd.DataFrame({'response': responses, 'pvalue': pvals, 'comment': comments},
+                          index=range(len(responses)))
+        df_test = df[~df.pvalue.isna()]
+        if df_test.empty:
+            logging.warning('No valid chunks to test.')
+            return None, df
+        mt = MultiTest(df_test.pvalue, stbl=self.HC_stbl)
+
+        if mt is None:
+            hc = np.nan
+            fisher = (np.nan, np.nan)
+            df['mask'] = pd.NA
+        else:
+            hc, hct = mt.hc(gamma=self.gamma)
+            fisher = mt.fisher()
+            bonferroni = mt.bonfferoni()
+            df['mask'] = df['pvalue'] <= hct
+        if dashboard:
+            mt.hc_dashboard(gamma=self.gamma)
+        return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1], bonf=bonferroni)
+
     def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
         return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
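Note on the new cache: the constructor reads cache_logloss_path with json.load, so the cache is expected to be a flat JSON object mapping a cleaned sentence to its precomputed logloss, with keys normalized the same way _get_responses normalizes lookups via clean_string. A minimal sketch of building such a cache file; the file name, sentences, and logloss values here are illustrative, not part of the commit:

import json
import re

def clean_string(s):
    # Mirror DetectLM.clean_string: drop literal \n/\r/\t escape sequences,
    # strip surrounding spaces and quotes, and lower-case
    s = re.sub(r'\\[nrt]', '', s)
    s = s.strip().strip("'")
    return s.lower()

# Hypothetical precomputed per-sentence logloss values; in practice these
# would come from the same sentence_detection_function that DetectLM uses
cache = {
    clean_string("The cat sat on the mat."): 3.71,
    clean_string("Entropy measures average surprise."): 4.02,
}

with open("logloss_cache.json", "w") as f:
    json.dump(cache, f)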
 
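A usage sketch of the updated class, assuming test_chunked_doc (not shown in this diff) forwards the dict built in _test_chunked_doc; the two callables and all inputs below are stand-ins, not part of the commit:

import numpy as np
from src.DetectLM import DetectLM

# Stand-in survival function: P-value of seeing a logloss at least this
# large for a sentence of the given length (a real one is calibrated on
# human-written text)
def survival_function_per_length(length, logloss):
    return float(np.exp(-logloss / max(length, 1)))

# Stand-in detector returning a per-sentence logloss
def sentence_detection_function(sent, context=None):
    return 3.5 + 0.1 * len(sent.split())

detector = DetectLM(sentence_detection_function,
                    survival_function_per_length,
                    length_limit_policy='truncate',
                    cache_logloss_path='logloss_cache.json')  # cache is optional

# Test a chunked document; contexts may be None
res = detector(["First sentence.", "Second sentence."], [None, None])
print(res['HC'], res['fisher_pvalue'], res['bonf'])

# Or reuse precomputed responses via the new from_responses;
# np.nan responses are skipped with comment "ignored (no response)"
res = detector.from_responses([3.7, 4.1, np.nan], [8, 9, 1])
print(res['sentences'])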