# detect-edits-in-ai-generated-text / human_text_detect.py
# Source: HuggingFace Space by idkash1 — "Update human_text_detect.py"
# (commit a4f1f33, verified)
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import numpy as np
import pickle
from src.DetectLM import DetectLM
from src.PerplexityEvaluator import PerplexityEvaluator
from src.PrepareArticles import PrepareArticles
from src.fit_survival_function import fit_per_length_survival_function
from glob import glob
import spacy
import re
import os
import gc
logging.basicConfig(level=logging.INFO)
def read_all_csv_files(pattern):
    """Concatenate every CSV file matching *pattern* into one DataFrame.

    Args:
        pattern: glob pattern, e.g. 'nullData/*.csv'.

    Returns:
        A single DataFrame with the rows of all matching files
        (original per-file indexes preserved); empty if nothing matches.
    """
    files = glob(pattern)
    if not files:
        return pd.DataFrame()
    # Read every frame first and concat once: concatenating inside the loop
    # (the original approach) re-copies all accumulated rows per file, O(n^2).
    return pd.concat([pd.read_csv(f) for f in files])
def get_survival_function(df, G=101):
    """Fit a per-sentence-length survival function from null scores.

    Args:
        df: DataFrame with a 'length' column and either a 'response'
            or a 'logloss' column of per-sentence scores.
        G:  number of interpolation points.

    Returns:
        A bivariate function (length, response) -> (0, 1).
    """
    assert not df.empty
    score_col = "response" if "response" in df.columns else "logloss"
    # Drop rows whose score is missing before fitting.
    valid = df[df[score_col].notna()]
    lengths = valid['length']
    scores = valid[score_col]
    return fit_per_length_survival_function(lengths, scores, log_space=True, G=G)
def mark_edits_remove_tags(chunks, tag="edit"):
    """Strip <tag>...</tag> markers from tagged chunks and flag which were edits.

    Args:
        chunks: dict with parallel lists 'text' and 'length'; mutated in place.
        tag:    tag name to search for (default "edit").

    Returns:
        (chunks, edits): the mutated chunks dict, and a list of booleans —
        one per chunk, True where a tag was found and stripped.
    """
    text_chunks = chunks['text']
    edits = []
    # Compile once outside the loop; captures the text between the tags.
    pattern = re.compile(rf"<{tag}>(.+)</{tag}>")
    for i, text in enumerate(text_chunks):
        # Fix: the original left an `import pdb; pdb.set_trace()` debugger
        # breakpoint here, which would freeze any non-interactive run.
        found = pattern.findall(text)
        if found:
            chunks['text'][i] = found[0]
            # The opening and closing tag tokens no longer count toward length.
            chunks['length'][i] -= 2
            edits.append(True)
        else:
            edits.append(False)
    return chunks, edits
def get_null_data(model_name, topic):
    """Load the pickled null-distribution data for a (model, topic) pair.

    Looks for 'nullData/{model_name}_{topic}.pkl' relative to the CWD.

    Returns:
        The unpickled object, or None when the file is missing or unreadable.
    """
    try:
        # 'with' guarantees the handle is closed (the original leaked it),
        # and the narrowed except no longer swallows unrelated errors.
        with open(f'nullData/{model_name}_{topic}.pkl', 'rb') as file:
            return pickle.load(file)
    except (OSError, pickle.UnpicklingError, EOFError):
        # Best-effort lookup: missing/corrupt null data simply yields None.
        return None
def get_threshold_obj(model_name, topic):
    """Look up the decision threshold for a (model, topic) pair.

    Reads 'threshold_obj.pkl' — a nested dict keyed by model name,
    then topic — from the current working directory.

    Returns:
        The threshold value, or None when the file or key is missing.
    """
    try:
        # 'with' closes the handle (the original leaked it); the narrowed
        # except still treats a missing file or key as "no threshold".
        with open('threshold_obj.pkl', 'rb') as file:
            threshold_obj = pickle.load(file)
        return threshold_obj[model_name][topic]
    except (OSError, pickle.UnpicklingError, EOFError, KeyError):
        return None
def detect_human_text(model_name, topic, threshold, text):
    """Score *text* against the null distribution for (model_name, topic).

    Loads the cached tokenizer/model for *model_name*, builds a DetectLM
    detector calibrated on the pickled null data for *topic*, splits *text*
    into sentence chunks, and compares the resulting HC statistic against
    *threshold*.

    Args:
        model_name: 'GPT2XL' selects gpt2-xl; anything else maps to phi-2.
            Also names the local cache directories under /data/cacheHuggingface.
        topic: topic key used to locate the pickled null data.
        threshold: decision threshold the HC statistic is divided by.
        text: the document to analyze (plain text; newline-separated lines).

    Returns:
        (rate, hc, sentences): relative deviation of HC from the threshold
        (rounded to 2 decimals), the rounded HC value, and the per-sentence
        results from DetectLM.
    """
    # Release leftover GPU/heap memory from any previous invocation.
    torch.cuda.empty_cache()
    gc.collect()

    # Get null data
    print('Get null data')
    # NOTE(review): get_null_data returns None when the pickle is missing,
    # which would make the `.columns` access below raise AttributeError —
    # confirm callers only pass known (model_name, topic) pairs.
    df_null = get_null_data(model_name, topic)
    if 'num' in df_null.columns:
        # Keep only null rows observed more than once — presumably to drop
        # unreliable single-occurrence entries; verify against null-data spec.
        df_null = df_null[df_null.num > 1]

    # Get survival function
    print('Get survival function')
    pval_functions = get_survival_function(df_null, G=43)

    # Sentence-length bounds (in tokens) shared by the parser and detector.
    min_tokens_per_sentence = 10
    max_tokens_per_sentence = 50

    cache_dir_tokenizer = f"/data/cacheHuggingface/{model_name}/tokenizer"

    # Init model
    print('Init tokenizer')
    # lm_name is the upstream HF model id; loading happens from the local
    # cache directory instead (the commented args show the original form).
    lm_name = 'gpt2-xl' if model_name == 'GPT2XL' else 'microsoft/phi-2'
    tokenizer = AutoTokenizer.from_pretrained(cache_dir_tokenizer) #lm_name, cache_dir=cache_dir_tokenizer
    # print("Save tokenizer")
    # tokenizer.save_pretrained(cache_dir_tokenizer)

    print('Init model')
    cache_dir_model = f"/data/cacheHuggingface/{model_name}/model"
    model = AutoModelForCausalLM.from_pretrained(cache_dir_model) #lm_name, cache_dir=cache_dir_model
    # print("Save model")
    # model.save_pretrained(cache_dir_model)

    print('Init PerplexityEvaluator')
    sentence_detector = PerplexityEvaluator(model, tokenizer)

    # Prefer Apple MPS, then CUDA, falling back to CPU.
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f'device {device}')
    model.to(device)

    print('Init DetectLM')
    detector = DetectLM(sentence_detector, pval_functions,
                        min_len=min_tokens_per_sentence,
                        max_len=max_tokens_per_sentence,
                        length_limit_policy='max_available',
                        HC_type='not stbl',
                        gamma=0.15,
                        ignore_first_sentence= False
                        )

    # Convert text to object
    print('Analyze text')
    article_obj = get_article_obj(text)
    parser = PrepareArticles(article_obj, min_tokens=min_tokens_per_sentence, max_tokens=max_tokens_per_sentence)
    chunks = parser(combined=False)

    # Go over all the document
    # NOTE(review): only the result of the LAST chunk survives the loop —
    # earlier chunks are scored and discarded. Confirm whether the parser
    # always yields a single chunk here, or whether aggregation is missing.
    for i in range(len(chunks['text'])):
        print(chunks['text'][i])
        # for p,v in enumerate(chunks['text'][i]):
        #     print(f'{p}: {v}')
        res = detector(chunks['text'][i], chunks['context'][i], dashboard=None)
        # print(f"Num of Edits (rate) = {np.sum(df['tag'] == '<edit>')} ({edit_rate})")
        # print(f"HC = {res['HC']}")
        # print(f"Fisher = {res['fisher']}")
        # print(f"Fisher (chisquared pvalue) = {res['fisher_pvalue']}")
        results = res['HC']

    # threshold = get_threshold_obj(model_name, topic)
    print(f"threshold: {threshold}, results: {results}")
    return round((results / threshold) - 1, 2), round(results, 2), res['sentences']
# Convert article text into object
# Convert article text into object
def get_article_obj(text):
    """Build the nested article structure that PrepareArticles expects.

    Each newline-delimited line of *text* is sentence-segmented with spaCy;
    only sentences of at least 10 spaCy tokens are kept, and runs of spaces
    are collapsed to single spaces.

    Returns:
        {'sub_titles': [{'sentences': [{'sentence': str}, ...]}]}
    """
    nlp = spacy.load("en_core_web_sm")  # English pipeline for segmentation
    kept = []
    for line in text.split('\n'):
        parsed = nlp(line)
        for sent in parsed.sents:
            if len(sent) < 10:
                continue  # too short to score reliably
            cleaned = re.sub(r' +', ' ', sent.text)
            kept.append({'sentence': cleaned})
    return {'sub_titles': [{'sentences': kept}]}