Upload 46 files
Browse files- .gitattributes +35 -35
- Dockerfile +20 -0
- README.md +11 -10
- app.py +76 -0
- cache/huggingface/New Text Document.txt +0 -0
- cache/matplotlib/New Text Document.txt +0 -0
- human_text_detect.py +169 -0
- nullData/GPT2XL_characters.pkl +3 -0
- nullData/GPT2XL_locations.pkl +3 -0
- nullData/GPT2XL_nature.pkl +3 -0
- nullData/GPT2XL_video_games_series_movies.pkl +3 -0
- nullData/GPT2XL_war.pkl +3 -0
- nullData/PHI2_characters.pkl +3 -0
- nullData/PHI2_locations.pkl +3 -0
- nullData/PHI2_nature.pkl +3 -0
- nullData/PHI2_video_games_series_movies.pkl +3 -0
- nullData/PHI2_war.pkl +3 -0
- requirements.txt +0 -0
- src/DetectLM.py +178 -0
- src/HC_survival_function.py +66 -0
- src/PerplexityEvaluator.py +34 -0
- src/PrepareArticles.py +74 -0
- src/PrepareSentenceContext.py +158 -0
- src/SentenceParser.py +31 -0
- src/__init__.py +0 -0
- src/__pycache__/DetectLM.cpython-310.pyc +0 -0
- src/__pycache__/DetectLM.cpython-38.pyc +0 -0
- src/__pycache__/HC_survival_function.cpython-310.pyc +0 -0
- src/__pycache__/HC_survival_function.cpython-38.pyc +0 -0
- src/__pycache__/PerplexityEvaluator.cpython-310.pyc +0 -0
- src/__pycache__/PerplexityEvaluator.cpython-312.pyc +0 -0
- src/__pycache__/PerplexityEvaluator.cpython-38.pyc +0 -0
- src/__pycache__/PrepareArticles.cpython-310.pyc +0 -0
- src/__pycache__/PrepareArticles.cpython-38.pyc +0 -0
- src/__pycache__/PrepareSentenceContext.cpython-310.pyc +0 -0
- src/__pycache__/PrepareSentenceContext.cpython-38.pyc +0 -0
- src/__pycache__/SentenceParser.cpython-310.pyc +0 -0
- src/__pycache__/SentenceParser.cpython-38.pyc +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/__init__.cpython-38.pyc +0 -0
- src/__pycache__/fit_survival_function.cpython-310.pyc +0 -0
- src/__pycache__/fit_survival_function.cpython-38.pyc +0 -0
- src/dataset_loaders.py +87 -0
- src/fit_survival_function.py +94 -0
- threshold_obj.pkl +3 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Python runtime
|
2 |
+
FROM python:3.8-slim
|
3 |
+
|
4 |
+
# Set the working directory
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# Copy the requirements file
|
8 |
+
COPY requirements.txt .
|
9 |
+
|
10 |
+
# Install dependencies
|
11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
12 |
+
|
13 |
+
# Copy the rest of the app
|
14 |
+
COPY . .
|
15 |
+
|
16 |
+
# Expose the Flask port
|
17 |
+
EXPOSE 5000
|
18 |
+
|
19 |
+
# Run the application
|
20 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
---
|
2 |
-
title: Detect Edits In AI-Generated Text
|
3 |
-
emoji: 馃憗
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
-
sdk: docker
|
7 |
-
pinned: false
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
1 |
+
---
|
2 |
+
title: Detect Edits In AI-Generated Text
|
3 |
+
emoji: 馃憗
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: mit
|
9 |
+
---
|
10 |
+
|
11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#https://www.freecodecamp.org/news/how-to-setup-virtual-environments-in-python/
|
2 |
+
#https://www.youtube.com/watch?v=qbLc5a9jdXo&ab_channel=CalebCurry
|
3 |
+
#https://stackoverflow.com/questions/26368306/export-is-not-recognized-as-an-internal-or-external-command
|
4 |
+
#python3 -m venv .venv
|
5 |
+
#source .venv/bin/activate
|
6 |
+
#
|
7 |
+
#pip freeze > requirements.txt
|
8 |
+
#$env:FLASK_APP="application.py" #set FLASK_APP=application.py # export FLASK_APP=application.py
|
9 |
+
#set FLASK_ENV=development #export FLASK_ENV=production
|
10 |
+
#flask run #flask run --host=0.0.0.0
|
11 |
+
|
12 |
+
#pip install torchvision
|
13 |
+
|
14 |
+
from flask import Flask, request, jsonify
|
15 |
+
from flask_cors import CORS
|
16 |
+
import pandas
|
17 |
+
from human_text_detect import detect_human_text
|
18 |
+
|
19 |
+
app = Flask(__name__)
|
20 |
+
CORS(app)
|
21 |
+
|
22 |
+
@app.route('/')
|
23 |
+
def index():
|
24 |
+
return 'Hello'
|
25 |
+
|
26 |
+
@app.route('/detectHumanInAIText/checkText', methods=['POST'])
|
27 |
+
def check_text():
|
28 |
+
|
29 |
+
# Get data
|
30 |
+
print('Get data')
|
31 |
+
data = request.get_json()
|
32 |
+
text = data.get('text')
|
33 |
+
model_name = data.get('model')
|
34 |
+
topic = data.get('topic')
|
35 |
+
|
36 |
+
# Validate data
|
37 |
+
print('Validate data')
|
38 |
+
answer = validate_data(text, model_name, topic)
|
39 |
+
if answer != '':
|
40 |
+
return jsonify({'answer': answer}), 400
|
41 |
+
|
42 |
+
topic = check_topic(topic)
|
43 |
+
answer = detect_human_text(model_name, topic, text)
|
44 |
+
|
45 |
+
return jsonify({'answer': answer})
|
46 |
+
|
47 |
+
def validate_data(text, model_name, topic):
|
48 |
+
if text is None or text == '':
|
49 |
+
return 'Text is missing'
|
50 |
+
|
51 |
+
if model_name is None or model_name == '':
|
52 |
+
return 'Model name is missing'
|
53 |
+
|
54 |
+
if topic is None or topic == '':
|
55 |
+
return 'Topic is missing'
|
56 |
+
|
57 |
+
if model_name not in ['GPT2XL', 'PHI2']:
|
58 |
+
return f'Model {model_name} not supported'
|
59 |
+
|
60 |
+
if topic not in ['Characters', 'Locations', 'Nature', 'Video games', 'Series', 'Movies', 'War']:
|
61 |
+
return f'Topic {topic} not supported'
|
62 |
+
|
63 |
+
return ''
|
64 |
+
|
65 |
+
def check_topic(topic):
|
66 |
+
topic_dict = {
|
67 |
+
'Characters': 'characters',
|
68 |
+
'Locations': 'locations',
|
69 |
+
'Nature': 'nature',
|
70 |
+
'Video games': 'video_games_series_movies',
|
71 |
+
'Series': 'video_games_series_movies',
|
72 |
+
'Movies': 'video_games_series_movies',
|
73 |
+
'War': 'war'
|
74 |
+
}
|
75 |
+
|
76 |
+
return topic_dict[topic]
|
cache/huggingface/New Text Document.txt
ADDED
File without changes
|
cache/matplotlib/New Text Document.txt
ADDED
File without changes
|
human_text_detect.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import pandas as pd
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
+
import logging
|
5 |
+
import numpy as np
|
6 |
+
import pickle
|
7 |
+
from src.DetectLM import DetectLM
|
8 |
+
from src.PerplexityEvaluator import PerplexityEvaluator
|
9 |
+
from src.PrepareArticles import PrepareArticles #Idan
|
10 |
+
from src.fit_survival_function import fit_per_length_survival_function
|
11 |
+
from glob import glob
|
12 |
+
import spacy
|
13 |
+
import re
|
14 |
+
|
15 |
+
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
|
18 |
+
|
19 |
+
def read_all_csv_files(pattern):
|
20 |
+
df = pd.DataFrame()
|
21 |
+
print(pattern)
|
22 |
+
for f in glob(pattern):
|
23 |
+
df = pd.concat([df, pd.read_csv(f)])
|
24 |
+
return df
|
25 |
+
|
26 |
+
|
27 |
+
def get_survival_function(df, G=101):
|
28 |
+
"""
|
29 |
+
Returns a survival function for every sentence length in tokens.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
:df: data frame with columns 'response' and 'length'
|
33 |
+
:G: number of interpolation points
|
34 |
+
|
35 |
+
Return:
|
36 |
+
bivariate function (length, responce) -> (0,1)
|
37 |
+
|
38 |
+
"""
|
39 |
+
assert not df.empty
|
40 |
+
value_name = "response" if "response" in df.columns else "logloss"
|
41 |
+
|
42 |
+
df1 = df[~df[value_name].isna()]
|
43 |
+
ll = df1['length']
|
44 |
+
xx1 = df1[value_name]
|
45 |
+
return fit_per_length_survival_function(ll, xx1, log_space=True, G=G)
|
46 |
+
|
47 |
+
|
48 |
+
def mark_edits_remove_tags(chunks, tag="edit"):
|
49 |
+
text_chunks = chunks['text']
|
50 |
+
edits = []
|
51 |
+
for i,text in enumerate(text_chunks):
|
52 |
+
chunk_text = re.findall(rf"<{tag}>(.+)</{tag}>", text)
|
53 |
+
if len(chunk_text) > 0:
|
54 |
+
import pdb; pdb.set_trace()
|
55 |
+
chunks['text'][i] = chunk_text[0]
|
56 |
+
chunks['length'][i] -= 2
|
57 |
+
edits.append(True)
|
58 |
+
else:
|
59 |
+
edits.append(False)
|
60 |
+
|
61 |
+
return chunks, edits
|
62 |
+
|
63 |
+
def get_null_data(model_name, topic):
|
64 |
+
data = None
|
65 |
+
try:
|
66 |
+
file = open(f'nullData/{model_name}_{topic}.pkl', 'rb')
|
67 |
+
data = pickle.load(file)
|
68 |
+
except:
|
69 |
+
pass
|
70 |
+
|
71 |
+
return data
|
72 |
+
|
73 |
+
def get_threshold_obj(model_name, topic):
|
74 |
+
threshold = None
|
75 |
+
try:
|
76 |
+
file = open('threshold_obj.pkl', 'rb')
|
77 |
+
threshold_obj = pickle.load(file)
|
78 |
+
threshold = threshold_obj[model_name][topic]
|
79 |
+
except:
|
80 |
+
pass
|
81 |
+
|
82 |
+
return threshold
|
83 |
+
|
84 |
+
def detect_human_text(model_name, topic, text):
|
85 |
+
|
86 |
+
# Get null data
|
87 |
+
print('Get null data')
|
88 |
+
df_null = get_null_data(model_name, topic)
|
89 |
+
if 'num' in df_null.columns:
|
90 |
+
df_null = df_null[df_null.num > 1]
|
91 |
+
|
92 |
+
# Get survival function
|
93 |
+
print('Get survival function')
|
94 |
+
pval_functions = get_survival_function(df_null, G=43)
|
95 |
+
|
96 |
+
min_tokens_per_sentence = 10
|
97 |
+
max_tokens_per_sentence = 100
|
98 |
+
|
99 |
+
# Init model
|
100 |
+
print('Init model')
|
101 |
+
lm_name = 'gpt2-xl' if model_name == 'GPT2XL' else 'microsoft/phi-2'
|
102 |
+
tokenizer = AutoTokenizer.from_pretrained(lm_name)
|
103 |
+
model = AutoModelForCausalLM.from_pretrained(lm_name)
|
104 |
+
|
105 |
+
print('Init PerplexityEvaluator')
|
106 |
+
sentence_detector = PerplexityEvaluator(model, tokenizer)
|
107 |
+
|
108 |
+
if torch.backends.mps.is_available():
|
109 |
+
device = 'mps'
|
110 |
+
elif torch.cuda.is_available():
|
111 |
+
device = 'cuda'
|
112 |
+
else:
|
113 |
+
device = 'cpu'
|
114 |
+
|
115 |
+
print(f'device {device}')
|
116 |
+
model.to(device)
|
117 |
+
|
118 |
+
print('Init DetectLM')
|
119 |
+
detector = DetectLM(sentence_detector, pval_functions,
|
120 |
+
min_len=min_tokens_per_sentence,
|
121 |
+
max_len=max_tokens_per_sentence,
|
122 |
+
length_limit_policy='truncate',
|
123 |
+
HC_type='stbl',
|
124 |
+
ignore_first_sentence= False
|
125 |
+
)
|
126 |
+
|
127 |
+
# Convert text to object
|
128 |
+
print('Analyze text')
|
129 |
+
article_obj = get_article_obj(text)
|
130 |
+
parser = PrepareArticles(article_obj, min_tokens=min_tokens_per_sentence, max_tokens=max_tokens_per_sentence)
|
131 |
+
chunks = parser(combined=False)
|
132 |
+
|
133 |
+
# Go over all the document
|
134 |
+
for i in range(len(chunks['text'])):
|
135 |
+
print(chunks['text'][i])
|
136 |
+
# for p,v in enumerate(chunks['text'][i]):
|
137 |
+
# print(f'{p}: {v}')
|
138 |
+
res = detector(chunks['text'][i], chunks['context'][i], dashboard=None)
|
139 |
+
|
140 |
+
# print(f"Num of Edits (rate) = {np.sum(df['tag'] == '<edit>')} ({edit_rate})")
|
141 |
+
# print(f"HC = {res['HC']}")
|
142 |
+
# print(f"Fisher = {res['fisher']}")
|
143 |
+
# print(f"Fisher (chisquared pvalue) = {res['fisher_pvalue']}")
|
144 |
+
|
145 |
+
results = res['HC']
|
146 |
+
|
147 |
+
threshold = get_threshold_obj(model_name, topic)
|
148 |
+
print(f"threshold: {threshold}, results: {results}")
|
149 |
+
return '1' if results >= threshold else '0'
|
150 |
+
|
151 |
+
# Convert article text into object
|
152 |
+
def get_article_obj(text):
|
153 |
+
# Init article object
|
154 |
+
article_obj = {
|
155 |
+
'sub_titles': [{
|
156 |
+
'sentences': []
|
157 |
+
}]
|
158 |
+
}
|
159 |
+
|
160 |
+
nlp = spacy.load("en_core_web_sm") # Load model
|
161 |
+
|
162 |
+
for line in text.split('\n'):
|
163 |
+
doc = nlp(line) # Analyze text
|
164 |
+
sentences = [sent.text for sent in doc.sents if len(sent) >= 10] # Split it by sentence
|
165 |
+
for sentence in sentences:
|
166 |
+
sentence = re.sub(r' +', ' ', sentence) # Remove duplicate spaces
|
167 |
+
article_obj['sub_titles'][0]['sentences'].append({'sentence': sentence})
|
168 |
+
|
169 |
+
return article_obj
|
nullData/GPT2XL_characters.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75c6c7b757dd7db42e73ae3fea662d5fc871be22d66b2784531c8996e3dfacc7
|
3 |
+
size 3168919
|
nullData/GPT2XL_locations.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6bba9cb0f09b801a43f1c2bfb04f30b9764ed106d7488db7d44abc207579bb6
|
3 |
+
size 3137467
|
nullData/GPT2XL_nature.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bbaab4e41c90faf4c4c8cd794b99045947a7aef5b19a65ed6ec2e0678673cd81
|
3 |
+
size 3192531
|
nullData/GPT2XL_video_games_series_movies.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aabd9d0e5fcb4dc98fbec83c6064df0fb168a172c078ef015afaebc0b1e54e39
|
3 |
+
size 3266168
|
nullData/GPT2XL_war.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:15805a913be9c9bb34daf6ce47b011b1f8388b708a0435cd23bf5efe886ebf37
|
3 |
+
size 3253367
|
nullData/PHI2_characters.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:585b56afdca713f6d1b152e69aeef58aa66abd5986c0d05363016b571568e2c1
|
3 |
+
size 3168919
|
nullData/PHI2_locations.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:09b0801020a2be2ac32355e38ba6efc4b7a6c5bfa2ad3677d2e0fcda56b54cf1
|
3 |
+
size 3137467
|
nullData/PHI2_nature.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b5b787d4cc5f74b882d064a5d58f8de3f456268e1121c78e2b4ba5b5db5a6c9
|
3 |
+
size 3192531
|
nullData/PHI2_video_games_series_movies.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b24cb4d919b42f20153e7b481a67e2a1e7079468af231b3e6219c803829184d2
|
3 |
+
size 3266168
|
nullData/PHI2_war.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:389ec634c5854434f65c1087d6384089a307e721436cca51dba061bcb30baccb
|
3 |
+
size 3253382
|
requirements.txt
ADDED
Binary file (3.04 kB). View file
|
|
src/DetectLM.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from multitest import MultiTest
|
4 |
+
from tqdm import tqdm
|
5 |
+
import logging
|
6 |
+
|
7 |
+
|
8 |
+
def truncae_to_max_no_tokens(text, max_no_tokens):
|
9 |
+
return " ".join(text.split()[:max_no_tokens])
|
10 |
+
|
11 |
+
|
12 |
+
class DetectLM(object):
|
13 |
+
def __init__(self, sentence_detection_function, survival_function_per_length,
|
14 |
+
min_len=4, max_len=100, HC_type="stbl",
|
15 |
+
length_limit_policy='truncate', ignore_first_sentence=False):
|
16 |
+
"""
|
17 |
+
Test for the presence of sentences of irregular origin as reflected by the
|
18 |
+
sentence_detection_function. The test is based on the sentence detection function
|
19 |
+
and the P-values obtained from the survival function of the detector's responses.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
----
|
23 |
+
:sentence_detection_function: a function returning the response of the text
|
24 |
+
under the detector. Typically, the response is a logloss value under some language model.
|
25 |
+
:survival_function_per_length: survival_function_per_length(l, x) is the probability of the language
|
26 |
+
model to produce a sentence value as extreme as x or more when the sentence s is the input to
|
27 |
+
the detector. The function is defined for every sentence length l.
|
28 |
+
The detector can also recieve a context c, in which case the input is the pair (s, c).
|
29 |
+
:length_limit_policy: When a sentence exceeds ``max_len``, we can:
|
30 |
+
'truncate': truncate sentence to the maximal length :max_len
|
31 |
+
'ignore': do not evaluate the response and P-value for this sentence
|
32 |
+
'max_available': use the logloss function of the maximal available length
|
33 |
+
:ignore_first_sentence: whether to ignore the first sentence in the document or not. Useful when assuming
|
34 |
+
context of the form previous sentence.
|
35 |
+
"""
|
36 |
+
|
37 |
+
self.survival_function_per_length = survival_function_per_length
|
38 |
+
self.sentence_detector = sentence_detection_function
|
39 |
+
self.min_len = min_len
|
40 |
+
self.max_len = max_len
|
41 |
+
self.length_limit_policy = length_limit_policy
|
42 |
+
self.ignore_first_sentence = ignore_first_sentence
|
43 |
+
self.HC_stbl = True if HC_type == 'stbl' else False
|
44 |
+
|
45 |
+
def _logperp(self, sent: str, context=None) -> float:
|
46 |
+
return float(self.sentence_detector(sent, context))
|
47 |
+
|
48 |
+
def _test_sentence(self, sentence: str, context=None):
|
49 |
+
return self._logperp(sentence, context)
|
50 |
+
|
51 |
+
def _get_length(self, sentence: str):
|
52 |
+
return len(sentence.split())
|
53 |
+
|
54 |
+
def _test_response(self, response: float, length: int):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
response: sentence logloss
|
58 |
+
length: sentence length in tokens
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
pvals: P-value of the logloss of the sentence
|
62 |
+
comments: comment on the P-value
|
63 |
+
"""
|
64 |
+
if self.min_len <= length:
|
65 |
+
comment = "OK"
|
66 |
+
if length > self.max_len: # in case length exceeds specifications...
|
67 |
+
if self.length_limit_policy == 'truncate':
|
68 |
+
length = self.max_len
|
69 |
+
comment = f"truncated to {self.max_len} tokens"
|
70 |
+
elif self.length_limit_policy == 'ignore':
|
71 |
+
comment = "ignored (above maximum limit)"
|
72 |
+
return np.nan, np.nan, comment
|
73 |
+
elif self.length_limit_policy == 'max_available':
|
74 |
+
comment = "exceeding length limit; resorting to max-available length"
|
75 |
+
length = self.max_len
|
76 |
+
pval = self.survival_function_per_length(length, response)
|
77 |
+
assert pval >= 0, "Negative P-value. Something is wrong."
|
78 |
+
return dict(response=response,
|
79 |
+
pvalue=pval,
|
80 |
+
length=length,
|
81 |
+
comment=comment)
|
82 |
+
else:
|
83 |
+
comment = "ignored (below minimal length)"
|
84 |
+
return dict(response=response,
|
85 |
+
pvalue=np.nan,
|
86 |
+
length=length,
|
87 |
+
comment=comment)
|
88 |
+
|
89 |
+
def _get_pvals(self, responses: list, lengths: list) -> tuple:
|
90 |
+
pvals = []
|
91 |
+
comments = []
|
92 |
+
for response, length in zip(responses, lengths):
|
93 |
+
r = self._test_response(response, length)
|
94 |
+
pvals.append(float(r['pvalue']))
|
95 |
+
comments.append(r['comment'])
|
96 |
+
return pvals, comments
|
97 |
+
|
98 |
+
|
99 |
+
def _get_responses(self, sentences: list, contexts: list) -> list:
|
100 |
+
"""
|
101 |
+
Compute response and length of a text sentence
|
102 |
+
"""
|
103 |
+
assert len(sentences) == len(contexts)
|
104 |
+
|
105 |
+
responses = []
|
106 |
+
lengths = []
|
107 |
+
for sent, ctx in tqdm(zip(sentences, contexts)):
|
108 |
+
logging.debug(f"Testing sentence: {sent} | context: {ctx}")
|
109 |
+
length = self._get_length(sent)
|
110 |
+
if self.length_limit_policy == 'truncate':
|
111 |
+
sent = truncae_to_max_no_tokens(sent, self.max_len)
|
112 |
+
if length == 1:
|
113 |
+
logging.warning(f"Sentence {sent} is too short. Skipping.")
|
114 |
+
responses.append(np.nan)
|
115 |
+
continue
|
116 |
+
try:
|
117 |
+
responses.append(self._test_sentence(sent, ctx))
|
118 |
+
except:
|
119 |
+
# something unusual happened...
|
120 |
+
import pdb; pdb.set_trace()
|
121 |
+
lengths.append(length)
|
122 |
+
return responses, lengths
|
123 |
+
|
124 |
+
def get_pvals(self, sentences: list, contexts: list) -> tuple:
|
125 |
+
"""
|
126 |
+
logloss test of every (sentence, context) pair
|
127 |
+
"""
|
128 |
+
assert len(sentences) == len(contexts)
|
129 |
+
|
130 |
+
responses, lengths = self._get_responses(sentences, contexts)
|
131 |
+
pvals, comments = self._get_pvals(responses, lengths)
|
132 |
+
|
133 |
+
return pvals, responses, comments
|
134 |
+
|
135 |
+
|
136 |
+
def testHC(self, sentences: list) -> float:
|
137 |
+
pvals = np.array(self.get_pvals(sentences)[1])
|
138 |
+
mt = MultiTest(pvals, stbl=self.HC_stbl)
|
139 |
+
return mt.hc(gamma=0.4)[0]
|
140 |
+
|
141 |
+
def testFisher(self, sentences: list) -> dict:
|
142 |
+
pvals = np.array(self.get_pvals(sentences)[1])
|
143 |
+
print(pvals)
|
144 |
+
mt = MultiTest(pvals, stbl=self.HC_stbl)
|
145 |
+
return dict(zip(['Fn', 'pvalue'], mt.fisher()))
|
146 |
+
|
147 |
+
def _test_chunked_doc(self, lo_chunks: list, lo_contexts: list) -> tuple:
|
148 |
+
pvals, responses, comments = self.get_pvals(lo_chunks, lo_contexts)
|
149 |
+
if self.ignore_first_sentence:
|
150 |
+
pvals[0] = np.nan
|
151 |
+
logging.info('Ignoring the first sentence.')
|
152 |
+
comments[0] = "ignored (first sentence)"
|
153 |
+
|
154 |
+
df = pd.DataFrame({'sentence': lo_chunks, 'response': responses, 'pvalue': pvals,
|
155 |
+
'context': lo_contexts, 'comment': comments},
|
156 |
+
index=range(len(lo_chunks)))
|
157 |
+
df_test = df[~df.pvalue.isna()]
|
158 |
+
if df_test.empty:
|
159 |
+
logging.warning('No valid chunks to test.')
|
160 |
+
return None, df
|
161 |
+
return MultiTest(df_test.pvalue, stbl=self.HC_stbl), df
|
162 |
+
|
163 |
+
def test_chunked_doc(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
|
164 |
+
mt, df = self._test_chunked_doc(lo_chunks, lo_contexts)
|
165 |
+
if mt is None:
|
166 |
+
hc = np.nan
|
167 |
+
fisher = (np.nan, np.nan)
|
168 |
+
df['mask'] = pd.NA
|
169 |
+
else:
|
170 |
+
hc, hct = mt.hc(gamma=0.4)
|
171 |
+
fisher = mt.fisher()
|
172 |
+
df['mask'] = df['pvalue'] <= hct
|
173 |
+
if dashboard:
|
174 |
+
mt.hc_dashboard(gamma=0.4)
|
175 |
+
return dict(sentences=df, HC=hc, fisher=fisher[0], fisher_pvalue=fisher[1])
|
176 |
+
|
177 |
+
def __call__(self, lo_chunks: list, lo_contexts: list, dashboard=False) -> dict:
|
178 |
+
return self.test_chunked_doc(lo_chunks, lo_contexts, dashboard=dashboard)
|
src/HC_survival_function.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script computes the survival function of the HC statistic for a given sample size n.
|
3 |
+
The survival function is computed using a simulation of the null distribution of the HC statistic.
|
4 |
+
We use the simulation results to fit a bivariate function of the form Pr[HC >= x | n] = f(n, x).
|
5 |
+
The simulation results are saved in a file named HC_null_sim_results.csv.
|
6 |
+
use function get_HC_survival_function to load the bivariate function or simulate the distribution.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
from multitest import MultiTest
|
12 |
+
from tqdm import tqdm
|
13 |
+
from scipy.interpolate import RectBivariateSpline
|
14 |
+
from src.fit_survival_function import fit_survival_func
|
15 |
+
import logging
|
16 |
+
|
17 |
+
HC_NULL_SIM_FILE = "HC_null_sim_results.csv"
|
18 |
+
STBL = True
|
19 |
+
NN = [25, 50, 75, 100, 125, 150, 200, 250, 300, 400, 500] # values of n to simulate
|
20 |
+
|
21 |
+
def get_HC_survival_function(HC_null_sim_file, log_space=True, nMonte=10000, STBL=True):
|
22 |
+
|
23 |
+
xx = {}
|
24 |
+
if HC_null_sim_file is None:
|
25 |
+
logging.info("Simulated HC null values file was not provided.")
|
26 |
+
for n in tqdm(NN):
|
27 |
+
logging.info(f"Simulating HC null values for n={n}...")
|
28 |
+
yy = np.zeros(nMonte)
|
29 |
+
for j in range(nMonte):
|
30 |
+
uu = np.random.rand(n)
|
31 |
+
mt = MultiTest(uu, stbl=STBL)
|
32 |
+
yy[j] = mt.hc()[0]
|
33 |
+
xx[n] = yy
|
34 |
+
nn = NN # Idan
|
35 |
+
else:
|
36 |
+
logging.info(f"Loading HC null values from {HC_null_sim_file}...")
|
37 |
+
df = pd.read_csv(HC_null_sim_file, index_col=0)
|
38 |
+
for n in df.index:
|
39 |
+
xx[n] = df.loc[n]
|
40 |
+
nn = df.index.tolist()
|
41 |
+
|
42 |
+
xx0 = np.linspace(-1, 10, 57)
|
43 |
+
zz = []
|
44 |
+
for n in nn:
|
45 |
+
univariate_survival_func = fit_survival_func(xx[n], log_space=log_space)
|
46 |
+
zz.append(univariate_survival_func(xx0))
|
47 |
+
|
48 |
+
func_log = RectBivariateSpline(np.array(nn), xx0, np.vstack(zz))
|
49 |
+
|
50 |
+
if log_space:
|
51 |
+
def func(x, y):
|
52 |
+
return np.exp(-func_log(x,y))
|
53 |
+
return func
|
54 |
+
else:
|
55 |
+
return func_log
|
56 |
+
|
57 |
+
|
58 |
+
def main():
|
59 |
+
func = get_HC_survival_function(HC_null_sim_file=HC_NULL_SIM_FILE, STBL=STBL)
|
60 |
+
print("Pr[HC >= 3 |n=50] = ", func(50, 3)[0][0]) # 9.680113e-05
|
61 |
+
print("Pr[HC >= 3 |n=100] = ", func(100, 3)[0][0]) # 0.0002335
|
62 |
+
print("Pr[HC >= 3 |n=200] = ", func(200, 3)[0][0]) # 0.00103771
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == '__main__':
|
66 |
+
main()
|
src/PerplexityEvaluator.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class PerplexityEvaluator(object):
|
4 |
+
def __init__(self, model, tokenizer, ignore_index=-1):
|
5 |
+
self.model = model
|
6 |
+
self.tokenizer = tokenizer
|
7 |
+
self.ignore_index = ignore_index
|
8 |
+
|
9 |
+
def __call__(self, text, context=None):
|
10 |
+
return self.log_perplexity(text, context)
|
11 |
+
|
12 |
+
def log_perplexity(self, text, context=None):
|
13 |
+
"""
|
14 |
+
Evaluate log perplexity of text with respect to the language model
|
15 |
+
based on the context
|
16 |
+
|
17 |
+
:param text:
|
18 |
+
:param context:
|
19 |
+
:return:
|
20 |
+
"""
|
21 |
+
device = self.model.device
|
22 |
+
text_ids = self.tokenizer(text, return_tensors='pt')
|
23 |
+
if context:
|
24 |
+
context_ids = self.tokenizer(context, return_tensors='pt')
|
25 |
+
input_ids = torch.concatenate([context_ids['input_ids'], text_ids['input_ids']], axis=1)
|
26 |
+
labels = torch.concatenate([torch.ones_like(context_ids['input_ids']) * self.ignore_index,
|
27 |
+
text_ids['input_ids']], axis=1)
|
28 |
+
print("Warning, need to remove context length when reporting lppx")
|
29 |
+
else:
|
30 |
+
input_ids = text_ids['input_ids']
|
31 |
+
labels = input_ids
|
32 |
+
|
33 |
+
loss = self.model(input_ids=input_ids.to(device), labels=labels.to(device)).loss
|
34 |
+
return loss.cpu().detach().numpy()
|
src/PrepareArticles.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import json
|
4 |
+
|
5 |
+
class PrepareArticles(object):
|
6 |
+
"""
|
7 |
+
Parse preprocessed data from csv
|
8 |
+
|
9 |
+
This information is needed for evaluating log-perplexity of the text with respect to a language model
|
10 |
+
and later on to test the likelihood that the sentence was sampled from the model with the relevant context.
|
11 |
+
"""
|
12 |
+
def __init__(self, article_obj, get_edits=False, min_tokens=10, max_tokens=100, max_sentences=None):
|
13 |
+
self.article_obj = article_obj
|
14 |
+
self.min_tokens = min_tokens
|
15 |
+
self.max_tokens = max_tokens
|
16 |
+
self.get_edits = get_edits
|
17 |
+
self.max_sentences = max_sentences
|
18 |
+
|
19 |
+
def __call__(self, combined=True):
|
20 |
+
return self.parse_dataset(combined)
|
21 |
+
|
22 |
+
def parse_dataset(self, combined=True):
|
23 |
+
|
24 |
+
texts = []
|
25 |
+
lengths = []
|
26 |
+
contexts = []
|
27 |
+
tags = []
|
28 |
+
|
29 |
+
current_texts = []
|
30 |
+
current_lengths = []
|
31 |
+
current_contexts = []
|
32 |
+
current_tags = []
|
33 |
+
exceeded_max_sentences = False
|
34 |
+
|
35 |
+
for sub_title in self.article_obj['sub_titles']: # For each sub title
|
36 |
+
for sentence in sub_title['sentences']: # Go over each sentence
|
37 |
+
sentence_size = len(sentence['sentence'].split())
|
38 |
+
if sentence_size >= self.min_tokens and sentence_size <= self.max_tokens:
|
39 |
+
current_texts.append(sentence['sentence'])
|
40 |
+
current_lengths.append(len(sentence['sentence'].split())) # Number of tokens
|
41 |
+
current_contexts.append(sentence['context'] if 'context' in sentence else None)
|
42 |
+
current_tags.append('no edits')
|
43 |
+
|
44 |
+
# If get_edits and has edited sentence save it
|
45 |
+
if self.get_edits and 'alternative' in sentence and len(sentence['alternative'].split()) >= self.min_tokens and len(sentence['alternative'].split()) <= self.max_tokens:
|
46 |
+
current_texts.append(sentence['alternative'])
|
47 |
+
current_lengths.append(len(sentence['alternative'].split()))
|
48 |
+
current_contexts.append(sentence['alternative_context'] if 'alternative_context' in sentence else None)
|
49 |
+
current_tags.append('<edit>')
|
50 |
+
if self.max_sentences and len(current_texts) >= self.max_sentences:
|
51 |
+
exceeded_max_sentences = True
|
52 |
+
break
|
53 |
+
# return {'text': np.array(texts, dtype=object), 'length': np.array(lengths, dtype=object), 'context': np.array(contexts, dtype=object), 'tag': np.array(tags, dtype=object),
|
54 |
+
# 'number_in_par': np.arange(1,1+len(texts))}
|
55 |
+
if exceeded_max_sentences:
|
56 |
+
break
|
57 |
+
|
58 |
+
# If exceede max sentences only if self.max_sentences is not None
|
59 |
+
if (self.max_sentences and exceeded_max_sentences) or (not self.max_sentences):
|
60 |
+
# If combined, combine the data
|
61 |
+
if combined:
|
62 |
+
texts = texts + current_texts
|
63 |
+
lengths = lengths + current_lengths
|
64 |
+
contexts = contexts + current_contexts
|
65 |
+
tags = tags + current_tags
|
66 |
+
else:
|
67 |
+
texts.append(np.array(current_texts))
|
68 |
+
lengths.append(np.array(current_lengths))
|
69 |
+
contexts.append(np.array(current_contexts))
|
70 |
+
tags.append(np.array(current_tags))
|
71 |
+
|
72 |
+
return {'text': np.array(texts, dtype=object), 'length': np.array(lengths, dtype=object), 'context': np.array(contexts, dtype=object), 'tag': np.array(tags, dtype=object),
|
73 |
+
'number_in_par': np.arange(1,1+len(texts))}
|
74 |
+
|
src/PrepareSentenceContext.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import spacy
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
from src.SentenceParser import SentenceParser
|
6 |
+
|
7 |
+
class PrepareSentenceContext(object):
|
8 |
+
"""
|
9 |
+
Parse text and extract length and context information
|
10 |
+
|
11 |
+
This information is needed for evaluating log-perplexity of the text with respect to a language model
|
12 |
+
and later on to test the likelihood that the sentence was sampled from the model with the relevant context.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, sentence_parser='spacy', context_policy=None, context=None):
|
16 |
+
if sentence_parser == 'spacy':
|
17 |
+
self.nlp = spacy.load("en_core_web_sm", disable=["tagger", "attribute_ruler", "lemmatizer", "ner"])
|
18 |
+
if sentence_parser == 'regex':
|
19 |
+
logging.warning("Regex-based parser is not good at breaking sentences like 'Dr. Stone', etc.")
|
20 |
+
self.nlp = SentenceParser()
|
21 |
+
|
22 |
+
self.sentence_parser_name = sentence_parser
|
23 |
+
|
24 |
+
self.context_policy = context_policy
|
25 |
+
self.context = context
|
26 |
+
|
27 |
+
def __call__(self, text):
|
28 |
+
return self.parse_sentences(text)
|
29 |
+
|
30 |
+
def parse_sentences(self, text):
|
31 |
+
pattern_close = r"(.*?)</edit>"
|
32 |
+
pattern_open = r"<edit>(.*?)"
|
33 |
+
MIN_TOKEN_LEN = 3
|
34 |
+
|
35 |
+
texts = []
|
36 |
+
tags = []
|
37 |
+
lengths = []
|
38 |
+
contexts = []
|
39 |
+
|
40 |
+
def update_sent(sent_text, tag, sent_length):
|
41 |
+
texts.append(sent_text)
|
42 |
+
tags.append(tag)
|
43 |
+
lengths.append(sent_length)
|
44 |
+
if self.context is not None:
|
45 |
+
context = self.context
|
46 |
+
elif self.context_policy is None:
|
47 |
+
context = None
|
48 |
+
elif self.context_policy == 'previous_sentence' and len(texts) > 0:
|
49 |
+
context = texts[-1]
|
50 |
+
else:
|
51 |
+
context = None
|
52 |
+
contexts.append(context)
|
53 |
+
|
54 |
+
curr_tag = None
|
55 |
+
parsed = self.nlp(text)
|
56 |
+
for s in parsed.sents:
|
57 |
+
prev_tag = curr_tag
|
58 |
+
matches_close = re.findall(pattern_close, s.text)
|
59 |
+
matches_open = re.findall(pattern_open, s.text)
|
60 |
+
matches_between = re.findall(r"<edit>(.*?)</edit>", s.text)
|
61 |
+
|
62 |
+
logging.debug(f"Current sentence: {s.text}")
|
63 |
+
logging.debug(f"Matches open: {matches_open}")
|
64 |
+
logging.debug(f"Matches close: {matches_close}")
|
65 |
+
logging.debug(f"Matches between: {matches_between}")
|
66 |
+
if len(matches_close)>0 and len(matches_open)>0:
|
67 |
+
logging.debug("Found an opening and a closing tag in the same sentence.")
|
68 |
+
if prev_tag is None and len(matches_open[0]) >= MIN_TOKEN_LEN:
|
69 |
+
logging.debug("Openning followed by closing with some text in between.")
|
70 |
+
update_sent(matches_open[0], "<edit>", len(s)-2)
|
71 |
+
curr_tag = None
|
72 |
+
if prev_tag == "<edit>" and len(matches_close[0]) >= MIN_TOKEN_LEN:
|
73 |
+
logging.warning(f"Wierd case: closing/openning followed by openning in sentence {len(texts)}")
|
74 |
+
update_sent(matches_close[0], prev_tag, len(s)-1)
|
75 |
+
curr_tag = None
|
76 |
+
if prev_tag == "</edit>":
|
77 |
+
logging.debug("Closing followed by openning.")
|
78 |
+
curr_tag = "<edit>"
|
79 |
+
if len(matches_between[0]) > MIN_TOKEN_LEN:
|
80 |
+
update_sent(matches_between[0], None, len(s)-2)
|
81 |
+
elif len(matches_open) > 0:
|
82 |
+
curr_tag = "<edit>"
|
83 |
+
assert prev_tag is None, f"Found an opening tag without a closing tag in sentence num. {len(texts)}"
|
84 |
+
if len(matches_open[0]) >= MIN_TOKEN_LEN:
|
85 |
+
# text and tag are in the same sentence
|
86 |
+
sent_text = matches_open[0]
|
87 |
+
update_sent(sent_text, curr_tag, len(s)-1)
|
88 |
+
elif len(matches_close) > 0:
|
89 |
+
curr_tag = "</edit>"
|
90 |
+
assert prev_tag == "<edit>", f"Found a closing tag without an opening tag in sentence num. {len(texts)}"
|
91 |
+
if len(matches_close[0]) >= MIN_TOKEN_LEN:
|
92 |
+
# text and tag are in the same sentence
|
93 |
+
update_sent(matches_close[0], prev_tag, len(s)-1)
|
94 |
+
curr_tag = None
|
95 |
+
else:
|
96 |
+
#if len(matches_close)==0 and len(matches_open)==0:
|
97 |
+
# no tag
|
98 |
+
update_sent(s.text, curr_tag, len(s))
|
99 |
+
return {'text': texts, 'length': lengths, 'context': contexts, 'tag': tags,
|
100 |
+
'number_in_par': np.arange(1,1+len(texts))}
|
101 |
+
|
102 |
+
def REMOVE_parse_sentences(self, text):
|
103 |
+
texts = []
|
104 |
+
contexts = []
|
105 |
+
lengths = []
|
106 |
+
tags = []
|
107 |
+
num_in_par = []
|
108 |
+
previous = None
|
109 |
+
|
110 |
+
text = re.sub("(</?[a-zA-Z0-9 ]+>\.?)\s+", r"\1.\n", text) # to make sure that tags are in separate sentences
|
111 |
+
#text = re.sub("(</[a-zA-Z0-9 ]+>\.?)\s+", r"\n\1.\n", text) # to make sure that tags are in separate sentences
|
112 |
+
|
113 |
+
parsed = self.nlp(text)
|
114 |
+
|
115 |
+
running_sent_num = 0
|
116 |
+
curr_tag = None
|
117 |
+
for i, sent in enumerate(parsed.sents):
|
118 |
+
# Here we try to track HTML-like tags. There might be
|
119 |
+
# some issues because spacy sentence parser has unexpected behavior when it comes to newlines
|
120 |
+
all_tags = re.findall(r"(</?[a-zA-Z0-9 ]+>)", str(sent))
|
121 |
+
if len(all_tags) > 1:
|
122 |
+
logging.error(f"More than one tag in sentence {i}: {all_tags}")
|
123 |
+
exit(1)
|
124 |
+
if len(all_tags) == 1:
|
125 |
+
tag = all_tags[0]
|
126 |
+
if tag[:2] == '</': # a closing tag
|
127 |
+
if curr_tag is None:
|
128 |
+
logging.warning(f"Closing tag without an opening tag in sentence {i}: {sent}")
|
129 |
+
else:
|
130 |
+
curr_tag = None
|
131 |
+
else:
|
132 |
+
if curr_tag is not None:
|
133 |
+
logging.warning(f"Opening tag without a closing tag in sentence {i}: {sent}")
|
134 |
+
else:
|
135 |
+
curr_tag = tag
|
136 |
+
else: # if text is not a tag
|
137 |
+
sent_text = str(sent)
|
138 |
+
sent_length = len(sent)
|
139 |
+
|
140 |
+
texts.append(sent_text)
|
141 |
+
running_sent_num += 1
|
142 |
+
num_in_par.append(running_sent_num)
|
143 |
+
tags.append(curr_tag)
|
144 |
+
lengths.append(sent_length)
|
145 |
+
|
146 |
+
if self.context is not None:
|
147 |
+
context = self.context
|
148 |
+
elif self.context_policy is None:
|
149 |
+
context = None
|
150 |
+
elif self.context_policy == 'previous_sentence':
|
151 |
+
context = previous
|
152 |
+
previous = sent_text
|
153 |
+
else:
|
154 |
+
context = None
|
155 |
+
|
156 |
+
contexts.append(context)
|
157 |
+
return {'text': texts, 'length': lengths, 'context': contexts, 'tag': tags,
|
158 |
+
'number_in_par': num_in_par}
|
src/SentenceParser.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
class Sentence(object):
|
5 |
+
def __init__(self, text):
|
6 |
+
self.text = text
|
7 |
+
self.tokens = text.split()
|
8 |
+
|
9 |
+
def __len__(self):
|
10 |
+
return len(self.tokens)
|
11 |
+
|
12 |
+
class Sentences(object):
|
13 |
+
def __init__(self, text):
|
14 |
+
def iterate(text):
|
15 |
+
for s in re.split(r"\n", text):
|
16 |
+
yield s
|
17 |
+
self.sents = iterate(text)
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.sents)
|
21 |
+
|
22 |
+
class SentenceParser(object):
|
23 |
+
"""
|
24 |
+
Iterate over the text column of a dataframe
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self):
|
28 |
+
self.sents = None
|
29 |
+
|
30 |
+
def __call__(self, text):
|
31 |
+
return Sentences(text)
|
src/__init__.py
ADDED
File without changes
|
src/__pycache__/DetectLM.cpython-310.pyc
ADDED
Binary file (6.96 kB). View file
|
|
src/__pycache__/DetectLM.cpython-38.pyc
ADDED
Binary file (6.99 kB). View file
|
|
src/__pycache__/HC_survival_function.cpython-310.pyc
ADDED
Binary file (2.45 kB). View file
|
|
src/__pycache__/HC_survival_function.cpython-38.pyc
ADDED
Binary file (2.45 kB). View file
|
|
src/__pycache__/PerplexityEvaluator.cpython-310.pyc
ADDED
Binary file (1.51 kB). View file
|
|
src/__pycache__/PerplexityEvaluator.cpython-312.pyc
ADDED
Binary file (2.23 kB). View file
|
|
src/__pycache__/PerplexityEvaluator.cpython-38.pyc
ADDED
Binary file (1.49 kB). View file
|
|
src/__pycache__/PrepareArticles.cpython-310.pyc
ADDED
Binary file (2.23 kB). View file
|
|
src/__pycache__/PrepareArticles.cpython-38.pyc
ADDED
Binary file (2.39 kB). View file
|
|
src/__pycache__/PrepareSentenceContext.cpython-310.pyc
ADDED
Binary file (4.49 kB). View file
|
|
src/__pycache__/PrepareSentenceContext.cpython-38.pyc
ADDED
Binary file (4.52 kB). View file
|
|
src/__pycache__/SentenceParser.cpython-310.pyc
ADDED
Binary file (1.62 kB). View file
|
|
src/__pycache__/SentenceParser.cpython-38.pyc
ADDED
Binary file (1.63 kB). View file
|
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes). View file
|
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (156 Bytes). View file
|
|
src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (138 Bytes). View file
|
|
src/__pycache__/fit_survival_function.cpython-310.pyc
ADDED
Binary file (2.28 kB). View file
|
|
src/__pycache__/fit_survival_function.cpython-38.pyc
ADDED
Binary file (2.29 kB). View file
|
|
src/dataset_loaders.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
|
3 |
+
SEED = 42
|
4 |
+
|
5 |
+
|
6 |
+
def get_dataset(name: str, machine_field, human_field, iterable=False,
|
7 |
+
text_field=None, shuffle=False, main_split='train'):
|
8 |
+
dataset = load_dataset(name)[main_split]
|
9 |
+
ds = dataset.rename_columns({human_field: 'human_text', machine_field: 'machine_text'})
|
10 |
+
if 'id' not in ds.features:
|
11 |
+
ids = list(range(len(ds)))
|
12 |
+
ds = ds.add_column("id", ids)
|
13 |
+
if text_field:
|
14 |
+
ds = ds.rename_columns({text_field: 'text'})
|
15 |
+
|
16 |
+
if iterable:
|
17 |
+
ds = ds.to_iterable_dataset()
|
18 |
+
if shuffle:
|
19 |
+
return ds.shuffle(seed=SEED)
|
20 |
+
else:
|
21 |
+
return ds
|
22 |
+
|
23 |
+
|
24 |
+
def get_text_from_wiki_dataset(shuffle=False, text_field=None):
|
25 |
+
return get_dataset(name="aadityaubhat/GPT-wiki-intro", machine_field='generated_intro',
|
26 |
+
human_field="wiki_intro", shuffle=shuffle, text_field=text_field)
|
27 |
+
|
28 |
+
|
29 |
+
def get_text_from_wiki_long_dataset(shuffle=False, text_field=None):
|
30 |
+
return get_dataset(name="alonkipnis/wiki-intro-long", machine_field='generated_intro',
|
31 |
+
human_field="wiki_intro", shuffle=shuffle, text_field=text_field)
|
32 |
+
|
33 |
+
|
34 |
+
def get_text_from_wiki_long_dataset_local(shuffle=False, text_field=None, iterable=False):
|
35 |
+
"""
|
36 |
+
A version of wiki_intro dataset with at least 15 sentences per generated article
|
37 |
+
"""
|
38 |
+
dataset = load_dataset("alonkipnis/wiki-intro-long")
|
39 |
+
ds = dataset.rename_columns({"wiki_intro": 'human_text', "generated_intro": 'machine_text'})
|
40 |
+
if text_field:
|
41 |
+
ds = ds.rename_columns({text_field: 'text'})
|
42 |
+
if iterable:
|
43 |
+
ds = ds.to_iterable_dataset()
|
44 |
+
if shuffle:
|
45 |
+
return ds.shuffle(seed=SEED)
|
46 |
+
else:
|
47 |
+
return ds
|
48 |
+
|
49 |
+
|
50 |
+
def get_text_from_chatgpt_news_long_dataset_local(shuffle=False, text_field=None, iterable=False):
|
51 |
+
"""
|
52 |
+
A version of chatgpt-news-articles dataset with at least 15 sentences per generated article
|
53 |
+
Only 'train' split is included
|
54 |
+
"""
|
55 |
+
dataset = load_dataset("alonkipnis/news-chatgpt-long")
|
56 |
+
ds = dataset.rename_columns({"article": 'human_text', "chatgpt": 'machine_text'})
|
57 |
+
if text_field:
|
58 |
+
ds = ds.rename_columns({text_field: 'text'})
|
59 |
+
if iterable:
|
60 |
+
ds = ds.to_iterable_dataset()
|
61 |
+
if shuffle:
|
62 |
+
return ds.shuffle(seed=SEED)
|
63 |
+
else:
|
64 |
+
return ds
|
65 |
+
|
66 |
+
def get_text_from_chatgpt_abstracts_dataset(shuffle=False, text_field=None):
|
67 |
+
return get_dataset(name="NicolaiSivesind/ChatGPT-Research-Abstracts", machine_field="generated_abstract",
|
68 |
+
human_field="real_abstract", shuffle=shuffle, text_field=text_field)
|
69 |
+
|
70 |
+
def get_text_from_chatgpt_news_long_dataset(shuffle=False, text_field=None):
|
71 |
+
return get_dataset(name="alonkipnis/news-chatgpt-long", machine_field='chatgpt',
|
72 |
+
human_field="article", shuffle=shuffle, text_field=text_field)
|
73 |
+
|
74 |
+
|
75 |
+
def get_text_from_chatgpt_news_dataset(shuffle=False, text_field=None):
|
76 |
+
return get_dataset(name="isarth/chatgpt-news-articles", machine_field='chatgpt',
|
77 |
+
human_field="article", shuffle=shuffle, text_field=text_field)
|
78 |
+
|
79 |
+
|
80 |
+
def get_text_from_wikibio_dataset(shuffle=False, text_field=None):
|
81 |
+
return get_dataset(name="potsawee/wiki_bio_gpt3_hallucination", machine_field='gpt3_text',
|
82 |
+
human_field="wiki_bio_text", shuffle=shuffle, text_field=text_field, main_split='evaluation')
|
83 |
+
|
84 |
+
## New datasets (22/5/2023)
|
85 |
+
def get_text_from_alpaca_gpt4_dataset(shuffle=False, text_field=None):
|
86 |
+
return get_dataset(name="polyware-ai/alpaca-gpt4-cleaned", machine_field='output',
|
87 |
+
human_field="instruction", shuffle=shuffle, text_field=text_field)
|
src/fit_survival_function.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script to read log-loss data of many sentences and characterize the empirical distribution.
|
3 |
+
We also report the mean log-loss as a function of sentence length
|
4 |
+
"""
|
5 |
+
from scipy.interpolate import RectBivariateSpline, interp1d
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
def fit_survival_func(xx, log_space=True):
|
9 |
+
"""
|
10 |
+
Returns an estimated survival function to the data in :xx: using
|
11 |
+
interpolation.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
:xx: data
|
15 |
+
:log_space: indicates whether fitting is in log space or not.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
univariate function
|
19 |
+
"""
|
20 |
+
assert len(xx) > 0
|
21 |
+
|
22 |
+
eps = 1 / len(xx)
|
23 |
+
inf = 1 / eps
|
24 |
+
|
25 |
+
sxx = np.sort(xx)
|
26 |
+
qq = np.mean(np.expand_dims(sxx,1) >= sxx, 0)
|
27 |
+
|
28 |
+
if log_space:
|
29 |
+
qq = -np.log(qq)
|
30 |
+
|
31 |
+
|
32 |
+
if log_space:
|
33 |
+
return interp1d(sxx, qq, fill_value=(0 , np.log(inf)), bounds_error=False)
|
34 |
+
else:
|
35 |
+
return interp1d(sxx, qq, fill_value=(1 , 0), bounds_error=False)
|
36 |
+
|
37 |
+
|
38 |
+
def fit_per_length_survival_function(lengths, xx, G=501, log_space=True):
|
39 |
+
"""
|
40 |
+
Returns a survival function for every sentence length in tokens.
|
41 |
+
Use 2D interpolation over the empirical survival function of the pairs (length, x)
|
42 |
+
|
43 |
+
Args:
|
44 |
+
:lengths:, :xx:, 1-D arrays
|
45 |
+
:G: number of grid points to use in the interpolation in the xx dimension
|
46 |
+
:log_space: indicates whether result is in log space or not.
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
bivariate function (length, x) -> [0,1]
|
50 |
+
"""
|
51 |
+
|
52 |
+
assert len(lengths) == len(xx)
|
53 |
+
|
54 |
+
min_tokens_per_sentence = lengths.min()
|
55 |
+
max_tokens_per_sentence = lengths.max()
|
56 |
+
ll = np.arange(min_tokens_per_sentence, max_tokens_per_sentence)
|
57 |
+
|
58 |
+
ppx_min_val = xx.min()
|
59 |
+
ppx_max_val = xx.max()
|
60 |
+
xx0 = np.linspace(ppx_min_val, ppx_max_val, G)
|
61 |
+
|
62 |
+
ll_valid = []
|
63 |
+
zz = []
|
64 |
+
for l in ll:
|
65 |
+
xx1 = xx[lengths == l]
|
66 |
+
if len(xx1) > 1:
|
67 |
+
univariate_survival_func = fit_survival_func(xx1, log_space=log_space)
|
68 |
+
ll_valid.append(l)
|
69 |
+
zz.append(univariate_survival_func(xx0))
|
70 |
+
|
71 |
+
func = RectBivariateSpline(np.array(ll_valid), xx0, np.vstack(zz))
|
72 |
+
if log_space:
|
73 |
+
def func2d(x, y):
|
74 |
+
return np.exp(-func(x,y))
|
75 |
+
return func2d
|
76 |
+
else:
|
77 |
+
return func
|
78 |
+
|
79 |
+
|
80 |
+
# import pickle
|
81 |
+
# import pandas as pd
|
82 |
+
# df = pd.read_csv('D:\\.Idan\\转讜讗专 砖谞讬\\转讝讛\\detectLM\\article_null.csv')
|
83 |
+
# LOGLOSS_PVAL_FUNC_FILE = 'D:\.Idan\转讜讗专 砖谞讬\转讝讛\detectLM\example\logloss_pval_function.pkl'
|
84 |
+
# LOGLOSS_PVAL_FUNC_FILE_TEST = 'D:\.Idan\转讜讗专 砖谞讬\转讝讛\detectLM\example\logloss_pval_function_test.pkl'
|
85 |
+
# with open(LOGLOSS_PVAL_FUNC_FILE, 'wb') as handle:
|
86 |
+
# pickle.dump(fit_per_length_survival_function(df['length'].values, df['response'].values), handle, protocol=pickle.HIGHEST_PROTOCOL)
|
87 |
+
|
88 |
+
# with open(LOGLOSS_PVAL_FUNC_FILE, 'rb') as f:
|
89 |
+
# data = pickle.load(f)
|
90 |
+
# print(data)
|
91 |
+
|
92 |
+
# with open(LOGLOSS_PVAL_FUNC_FILE_TEST, 'rb') as f:
|
93 |
+
# data = pickle.load(f)
|
94 |
+
# print(data)
|
threshold_obj.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b9b1dfc4fc552a4c975ebe1f05a5140bee30fc8231fd4b4eba1dcf4082d127a
|
3 |
+
size 208
|