minko186 commited on
Commit
97dec89
·
1 Parent(s): 2aee0ff

Create explainability.py

Browse files
Files changed (1) hide show
  1. explainability.py +98 -0
explainability.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, textstat
2
+ from nltk import FreqDist
3
+ from nltk.corpus import stopwords
4
+ from nltk.tokenize import word_tokenize, sent_tokenize
5
+ import torch
6
+ import nltk
7
+ from tqdm import tqdm
8
+
9
+ nltk.download('punkt')
10
+
11
+ def normalize(value, min_value, max_value):
12
+ normalized_value = ((value - min_value) * 100) / (max_value - min_value)
13
+ return max(0, min(100, normalized_value))
14
+
15
+ def preprocess_text1(text):
16
+ text = text.lower()
17
+ text = re.sub(r'[^\w\s]', '', text) # remove punctuation
18
+ stop_words = set(stopwords.words('english')) # remove stopwords
19
+ words = [word for word in text.split() if word not in stop_words]
20
+ words = [word for word in words if not word.isdigit()] # remove numbers
21
+ return words
22
+
23
+ def vocabulary_richness_ttr(words):
24
+ unique_words = set(words)
25
+ ttr = len(unique_words) / len(words) * 100
26
+ return ttr
27
+
28
+ def calculate_gunning_fog(text):
29
+ """range 0-20"""
30
+ gunning_fog = textstat.gunning_fog(text)
31
+ return gunning_fog
32
+
33
+ def calculate_automated_readability_index(text):
34
+ """range 1-20"""
35
+ ari = textstat.automated_readability_index(text)
36
+ return ari
37
+
38
+ def calculate_flesch_reading_ease(text):
39
+ """range 0-100"""
40
+ fre = textstat.flesch_reading_ease(text)
41
+ return fre
42
+
43
+ def preprocess_text2(text):
44
+ sentences = sent_tokenize(text)
45
+ words = [word.lower() for sent in sentences for word in word_tokenize(sent) if word.isalnum()]
46
+ stop_words = set(stopwords.words('english'))
47
+ words = [word for word in words if word not in stop_words]
48
+ return words, sentences
49
+
50
+ def calculate_average_sentence_length(sentences):
51
+ """range 0-40 or 50 based on the histogram"""
52
+ total_words = sum(len(word_tokenize(sent)) for sent in sentences)
53
+ average_sentence_length = total_words / (len(sentences) + 0.0000001)
54
+ return average_sentence_length
55
+
56
+ def calculate_average_word_length(words):
57
+ """range 0-8 based on the histogram"""
58
+ total_characters = sum(len(word) for word in words)
59
+ average_word_length = total_characters / (len(words) + 0.0000001)
60
+ return average_word_length
61
+
62
+ def calculate_max_depth(sent):
63
+ return max(len(list(token.ancestors)) for token in sent)
64
+
65
+ def calculate_syntactic_tree_depth(nlp, text):
66
+ """0-10 based on the histogram"""
67
+ doc = nlp(text)
68
+ sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
69
+ average_depth = sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
70
+ return average_depth
71
+
72
+ def calculate_perplexity(text, model, tokenizer, device, stride=512):
73
+ """range 0-30 based on the histogram"""
74
+ encodings = tokenizer(text, return_tensors="pt")
75
+ max_length = model.config.n_positions
76
+ seq_len = encodings.input_ids.size(1)
77
+
78
+ nlls = []
79
+ prev_end_loc = 0
80
+ for begin_loc in tqdm(range(0, seq_len, stride)):
81
+ end_loc = min(begin_loc + max_length, seq_len)
82
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
83
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
84
+ target_ids = input_ids.clone()
85
+ target_ids[:, :-trg_len] = -100
86
+
87
+ with torch.no_grad():
88
+ outputs = model(input_ids, labels=target_ids)
89
+ neg_log_likelihood = outputs.loss
90
+
91
+ nlls.append(neg_log_likelihood)
92
+
93
+ prev_end_loc = end_loc
94
+ if end_loc == seq_len:
95
+ break
96
+
97
+ ppl = torch.exp(torch.stack(nlls).mean())
98
+ return ppl.item()