copyright_checker / explainability.py
minko186's picture
refactoring
45d10c4
raw
history blame
3.5 kB
import re, textstat
from nltk import FreqDist
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import torch
import nltk
from tqdm import tqdm
# Fetch the NLTK resources this module relies on: "punkt" backs
# sent_tokenize/word_tokenize, and "stopwords" backs stopwords.words("english"),
# which otherwise raises LookupError when the corpus is not installed.
nltk.download("punkt")
nltk.download("stopwords")
def normalize(value, min_value, max_value):
    """Rescale ``value`` from ``[min_value, max_value]`` onto a 0-100 score.

    The mapping is linear: ``min_value`` maps to 0 and ``max_value`` to 100.
    Results that fall outside 0-100 are clamped to the nearest bound.

    Note: with plain Python numbers, ``min_value == max_value`` raises
    ``ZeroDivisionError``.
    """
    offset = (value - min_value) * 100
    span = max_value - min_value
    scaled = offset / span
    # Clamp to the upper bound first, then to the lower bound.
    capped = scaled if scaled < 100 else 100
    return capped if capped > 0 else 0
def preprocess_text1(text):
    """Return the content words of ``text`` as a list of lowercase tokens.

    The text is lowercased, every character that is neither a word character
    nor whitespace is removed, and the remainder is split on whitespace.
    Tokens found in NLTK's English stopword list and tokens consisting only
    of digits are then discarded.
    """
    stripped = re.sub(r"[^\w\s]", "", text.lower())  # remove punctuation
    english_stopwords = set(stopwords.words("english"))
    return [
        token
        for token in stripped.split()
        # drop stopwords and numbers in a single pass
        if token not in english_stopwords and not token.isdigit()
    ]
def vocabulary_richness_ttr(words):
    """Return the type-token ratio (TTR) of ``words`` as a percentage.

    TTR is the number of distinct words divided by the total number of
    words, multiplied by 100.

    Args:
        words: A sized collection of hashable tokens.

    Returns:
        The ratio as a float in ``[0, 100]``; ``0.0`` when ``words`` is empty
        (instead of raising ``ZeroDivisionError``).
    """
    total = len(words)
    if total == 0:
        # No tokens at all: there is no vocabulary to measure.
        return 0.0
    return len(set(words)) / total * 100
def calculate_gunning_fog(text):
    """Return the Gunning fog score of ``text`` as computed by ``textstat``.

    Original author's note: expected range 0-20.
    """
    return textstat.gunning_fog(text)
def calculate_automated_readability_index(text):
    """Return the automated readability index of ``text`` via ``textstat``.

    Original author's note: expected range 1-20.
    """
    return textstat.automated_readability_index(text)
def calculate_flesch_reading_ease(text):
    """Return the Flesch reading-ease score of ``text`` via ``textstat``.

    Original author's note: expected range 0-100.
    """
    return textstat.flesch_reading_ease(text)
def preprocess_text2(text):
    """Split ``text`` into sentences and lowercase content words.

    Sentences come from ``sent_tokenize``. Each sentence is word-tokenized;
    only tokens for which ``str.isalnum()`` is true are kept, lowercased, and
    then filtered against NLTK's English stopword list.

    Returns:
        A ``(words, sentences)`` tuple: the filtered lowercase tokens and the
        list of sentences produced by ``sent_tokenize``.
    """
    sentences = sent_tokenize(text)

    tokens = []
    for sentence in sentences:
        for token in word_tokenize(sentence):
            # isalnum() discards punctuation-only and mixed-symbol tokens.
            if token.isalnum():
                tokens.append(token.lower())

    english_stopwords = set(stopwords.words("english"))
    content_words = [token for token in tokens if token not in english_stopwords]
    return content_words, sentences
def calculate_average_sentence_length(sentences):
    """Return the mean number of tokens per sentence.

    Each sentence is tokenized with ``word_tokenize`` and the token counts
    are averaged.

    Original author's note: range 0-40 or 50 based on the histogram.

    Args:
        sentences: A sized collection of sentence strings.

    Returns:
        The average token count as a float; ``0.0`` when ``sentences`` is
        empty.
    """
    sentence_count = len(sentences)
    if sentence_count == 0:
        # Explicit guard instead of adding an epsilon to the denominator,
        # which would slightly bias every non-empty average.
        return 0.0
    total_words = sum(len(word_tokenize(sent)) for sent in sentences)
    return total_words / sentence_count
def calculate_average_word_length(words):
    """Return the mean number of characters per word.

    Original author's note: range 0-8 based on the histogram.

    Args:
        words: A sized collection of strings.

    Returns:
        The average length as a float; ``0.0`` when ``words`` is empty.
    """
    word_count = len(words)
    if word_count == 0:
        # Explicit guard instead of adding an epsilon to the denominator,
        # which would slightly bias every non-empty average.
        return 0.0
    total_characters = sum(len(word) for word in words)
    return total_characters / word_count
def calculate_max_depth(sent):
    """Return the largest ancestor count among the tokens of ``sent``.

    Args:
        sent: An iterable of tokens, each exposing an iterable ``ancestors``
            attribute (presumably a spaCy sentence span — confirm against
            the caller).

    Returns:
        The maximum of ``len(list(token.ancestors))`` over all tokens, or
        ``0`` when ``sent`` contains no tokens (instead of raising
        ``ValueError`` from ``max()`` on an empty sequence).
    """
    return max((len(list(token.ancestors)) for token in sent), default=0)
def calculate_syntactic_tree_depth(nlp, text):
    """Return the average per-sentence maximum tree depth of ``text``.

    ``nlp`` is called on ``text`` and the result's ``sents`` are iterated
    (presumably a spaCy pipeline and ``Doc`` — confirm against the caller).
    Each sentence is scored with ``calculate_max_depth`` and the scores are
    averaged.

    Original author's note: range 0-10 based on the histogram.

    Returns:
        The mean depth, or ``0`` when the parsed text yields no sentences.
    """
    depths = [calculate_max_depth(sentence) for sentence in nlp(text).sents]
    if not depths:
        return 0
    return sum(depths) / len(depths)
def calculate_perplexity(text, model, tokenizer, device, stride=512):
    """Estimate the perplexity of ``text`` under ``model`` with a sliding window.

    The text is tokenized once and scored in windows of at most
    ``model.config.n_positions`` tokens, with the window start advancing by
    ``stride`` tokens. Within a window, only the tokens beyond the end of the
    previous window keep their labels; earlier positions are set to -100
    (presumably the label value the model's loss ignores — confirm for the
    model in use) so they serve as context only. The result is the
    exponential of the unweighted mean of the per-window losses.

    Original author's note: range 0-30 based on the histogram.

    Args:
        text: The string to score.
        model: Model called as ``model(input_ids, labels=...)`` whose output
            exposes ``.loss`` and whose ``config`` exposes ``n_positions``.
        tokenizer: Callable returning an object with ``input_ids`` when
            invoked with ``return_tensors="pt"``.
        device: Device the input ids are moved to before the forward pass.
        stride: Step, in tokens, between consecutive window starts.

    Returns:
        The perplexity as a Python float.
    """
    token_ids = tokenizer(text, return_tensors="pt").input_ids
    context_size = model.config.n_positions
    total_tokens = token_ids.size(1)

    window_losses = []
    scored_upto = 0  # end index of the previous window
    for window_start in tqdm(range(0, total_tokens, stride)):
        window_end = min(window_start + context_size, total_tokens)
        # Tokens not covered by the previous window; may be different from
        # stride on the last loop.
        fresh_tokens = window_end - scored_upto

        input_ids = token_ids[:, window_start:window_end].to(device)
        target_ids = input_ids.clone()
        # Mask everything except the trailing ``fresh_tokens`` positions.
        target_ids[:, :-fresh_tokens] = -100

        with torch.no_grad():
            window_losses.append(model(input_ids, labels=target_ids).loss)

        scored_upto = window_end
        if window_end == total_tokens:
            break

    return torch.exp(torch.stack(window_losses).mean()).item()