Spaces:
Runtime error
Runtime error
File size: 6,304 Bytes
f75d1f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import importlib
import math
import re
import subprocess
import sys

import gradio as gr
import nltk
import plotly.graph_objects as go
import spacy
import textstat
import torch
from nltk import FreqDist
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Fetch the NLTK resources used by the tokenizers and stopword filters below.
nltk.download('stopwords')
nltk.download('punkt')
# Newer NLTK releases look up the 'punkt_tab' resource for word/sentence
# tokenization instead of 'punkt'; fetching it is harmless where it is unused.
nltk.download('punkt_tab')

# Make sure the spaCy English pipeline is installed *before* loading it.
# Loading first and downloading afterwards (the previous order) fails on a
# fresh environment, because spacy.load raises OSError for a missing package.
SPACY_MODEL = "en_core_web_sm"
# Use the running interpreter so the package is installed into the same
# environment this process imports from.
command = [sys.executable, '-m', 'spacy', 'download', SPACY_MODEL, '-q']
try:
    nlp = spacy.load(SPACY_MODEL)
except OSError:
    # Execute the download only when the model is actually missing.
    subprocess.run(command, check=True)
    # The package directory was created after interpreter start-up; clear the
    # import system's cached directory listings so it can be found.
    importlib.invalidate_caches()
    nlp = spacy.load(SPACY_MODEL)
# --- Perplexity scoring model ------------------------------------------------
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "gpt2-large"
# Tokenizer and language-model head are loaded from the same checkpoint id;
# only the model is moved to `device`.
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
def normalize(value, min_value, max_value):
    """Linearly map ``value`` from [min_value, max_value] onto [0, 100].

    Results outside the target interval are clamped to 0 or 100.
    A degenerate range (``max_value == min_value``) raises ZeroDivisionError.
    """
    span = max_value - min_value
    scaled = ((value - min_value) * 100) / span
    # Same comparison order as min(100, x) followed by max(0, x).
    capped = scaled if scaled < 100 else 100
    return capped if capped > 0 else 0
# vocabulary richness
def preprocess_text1(text):
    """Tokenize ``text`` for the type-token-ratio (vocabulary richness) metric.

    Steps: lowercase, delete every character that is neither a word character
    nor whitespace, split on whitespace, then drop NLTK English stopwords and
    tokens made only of digits. Returns the surviving tokens in order.
    """
    stripped = re.sub(r'[^\w\s]', '', text.lower())
    stop_words = set(stopwords.words('english'))
    kept = []
    for token in stripped.split():
        if token in stop_words or token.isdigit():
            continue
        kept.append(token)
    return kept
def vocabulary_richness_ttr(words):
    """Return the type-token ratio of ``words`` as a percentage (0-100).

    TTR = (number of distinct words / total number of words) * 100.

    An empty list returns 0.0 instead of raising ZeroDivisionError. This
    happens when the input text is empty or consists only of stopwords,
    numbers or punctuation after preprocessing.
    """
    if not words:
        return 0.0
    return len(set(words)) / len(words) * 100
def calculate_gunning_fog(text):
    """Gunning fog readability index of ``text``, as computed by textstat.

    radar_plot rescales this value assuming it lies roughly in 0-20.
    """
    return textstat.gunning_fog(text)
def calculate_automated_readability_index(text):
    """Automated Readability Index of ``text``, as computed by textstat.

    Typically around 1-20. Not called by radar_plot in this file.
    """
    return textstat.automated_readability_index(text)
def calculate_flesch_reading_ease(text):
    """Flesch reading-ease score of ``text``, as computed by textstat.

    Nominally 0-100, where higher means easier. The Flesch formula can produce
    values outside that interval for extremely simple or dense text.
    Not called by radar_plot in this file.
    """
    return textstat.flesch_reading_ease(text)
def preprocess_text2(text):
    """Split ``text`` into sentences and content words.

    Returns ``(words, sentences)``:
      * ``sentences``: the sentence strings produced by ``sent_tokenize``.
      * ``words``: every token from every sentence that is purely alphanumeric
        (so punctuation tokens are discarded), lowercased, with NLTK English
        stopwords removed, in original order.
    """
    sentences = sent_tokenize(text)
    candidates = []
    for sentence in sentences:
        candidates.extend(tok.lower() for tok in word_tokenize(sentence) if tok.isalnum())
    stop_words = set(stopwords.words('english'))
    words = list(filter(lambda w: w not in stop_words, candidates))
    return words, sentences
def calculate_average_sentence_length(sentences):
    """Return the mean number of tokens per sentence.

    Tokens come from ``word_tokenize``, so punctuation tokens are counted as
    well. radar_plot rescales the result assuming roughly 0-40 (the original
    note: "range 0-40 or 50 based on the histogram").

    An empty list returns 0.0. This explicit guard replaces the former
    ``+ 0.0000001`` in the denominator, which slightly skewed every
    non-empty result.
    """
    if not sentences:
        return 0.0
    total_tokens = sum(len(word_tokenize(sent)) for sent in sentences)
    return total_tokens / len(sentences)
def calculate_average_word_length(words):
    """Return the mean character length of ``words``.

    radar_plot rescales the result assuming roughly 0-8 ("based on the
    histogram").

    An empty list returns 0.0. This explicit guard replaces the former
    ``+ 0.0000001`` in the denominator, which made e.g. ["ab", "cde"] yield
    2.4999998... instead of 2.5.
    """
    if not words:
        return 0.0
    return sum(len(word) for word in words) / len(words)
def calculate_max_depth(sent):
    """Return the largest ancestor count of any token in ``sent``.

    ``sent`` is iterated for tokens. Each token must expose an iterable
    ``ancestors`` (spaCy ``Token.ancestors`` when called from
    calculate_syntactic_tree_depth). An empty ``sent`` raises ValueError
    from ``max``.
    """
    depths = []
    for token in sent:
        depths.append(sum(1 for _ in token.ancestors))
    return max(depths)
def calculate_syntactic_tree_depth(text):
    """Average, over the sentences spaCy finds, of each sentence's max token depth.

    Parses ``text`` with the module-level ``nlp`` pipeline. Returns 0 when no
    sentences are produced. radar_plot rescales the result assuming roughly
    0-10 ("based on the histogram").
    """
    doc = nlp(text)
    depths = list(map(calculate_max_depth, doc.sents))
    if not depths:
        return 0
    return sum(depths) / len(depths)
# reference: https://huggingface.co/docs/transformers/perplexity
def calculate_perplexity(text, stride=512):
    """Return the GPT-2 perplexity of ``text`` using strided sliding windows.

    Follows the Hugging Face "Perplexity of fixed-length models" recipe:
      * Windows hold at most ``model.config.n_positions`` tokens.
      * Each window starts ``stride`` tokens after the previous one.
      * Tokens already scored by the previous window are masked (-100) and
        serve only as context, so no token is scored twice.

    Each window's mean loss is weighted by the number of tokens it actually
    scored before averaging. A short final window therefore no longer counts
    as much as a full one.

    radar_plot rescales the result assuming roughly 0-30 ("based on the
    histogram").

    Args:
        text: input string.
        stride: window step in tokens; must satisfy 0 < stride <= n_positions.
            A larger stride would leave tokens between windows unscored.

    Returns:
        Perplexity as a Python float. Returns ``nan`` when the text has fewer
        than two tokens, because there is nothing to predict. Previously an
        empty string crashed in ``torch.stack``.

    Raises:
        ValueError: if ``stride`` is outside (0, n_positions].
    """
    max_length = model.config.n_positions
    if not 0 < stride <= max_length:
        raise ValueError(f"stride must be in (0, {max_length}], got {stride}")

    encodings = tokenizer(text, return_tensors="pt")
    seq_len = encodings.input_ids.size(1)

    nll_sum = 0.0  # sum of per-token negative log-likelihoods
    n_scored = 0   # number of tokens contributing to nll_sum
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Context-only tokens (already scored by the previous window) are ignored.
        target_ids[:, :-trg_len] = -100
        # The model shifts labels left by one internally, so position 0 of a
        # window is never scored; count only the labels that reach the loss.
        n_window = int((target_ids[:, 1:] != -100).sum().item())
        if n_window:
            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
            # outputs.loss is the mean NLL over the n_window scored tokens.
            nll_sum += outputs.loss.item() * n_window
            n_scored += n_window
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if n_scored == 0:
        return float("nan")
    return math.exp(nll_sum / n_scored)
def _compute_features(input_text):
    """Compute the six writing-analysis scores for ``input_text``.

    Returns a dict mapping each radar-axis label to a score on a 0-100 scale.
    Every metric except vocabulary richness is rescaled with ``normalize``
    over a fixed, empirically chosen range. The type-token ratio is already
    a percentage.
    """
    # vocabulary richness
    processed_words = preprocess_text1(input_text)
    ttr_value = vocabulary_richness_ttr(processed_words)
    # readability
    gunning_fog = calculate_gunning_fog(input_text)
    gunning_fog_norm = normalize(gunning_fog, min_value=0, max_value=20)
    # average sentence length and average word length
    words, sentences = preprocess_text2(input_text)
    average_sentence_length = calculate_average_sentence_length(sentences)
    average_word_length = calculate_average_word_length(words)
    average_sentence_length_norm = normalize(average_sentence_length, min_value=0, max_value=40)
    average_word_length_norm = normalize(average_word_length, min_value=0, max_value=8)
    # syntactic tree depth
    average_tree_depth = calculate_syntactic_tree_depth(input_text)
    average_tree_depth_norm = normalize(average_tree_depth, min_value=0, max_value=10)
    # perplexity
    perplexity = calculate_perplexity(input_text)
    perplexity_norm = normalize(perplexity, min_value=0, max_value=30)
    # Insertion order determines the order of the axes around the chart.
    return {
        "readability": gunning_fog_norm,
        "syntactic tree depth": average_tree_depth_norm,
        "vocabulary richness": ttr_value,
        "perplexity": perplexity_norm,
        "average sentence length": average_sentence_length_norm,
        "average word length": average_word_length_norm,
    }


def _build_radar_figure(features):
    """Render ``features`` (label -> 0-100 score) as a filled plotly polar chart."""
    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=list(features.values()),
        theta=list(features.keys()),
        fill='toself',
        name='Radar Plot'
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
            )),
        showlegend=False,
        margin=dict(l=10, r=20, b=10, t=10),
    )
    return fig


def radar_plot(input_text):
    """Gradio callback: analyse ``input_text`` and return a radar chart of its scores.

    Blank input raises ``gr.Error``, which Gradio shows to the user as a
    message. Without this check, blank input failed deep inside the metric
    code with a generic error. The computed scores are printed as a debug
    trace.
    """
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter some text to analyse.")
    features = _compute_features(input_text)
    print(features)
    return _build_radar_figure(features)
# --- Gradio UI -----------------------------------------------------------------
# A single free-text input feeds radar_plot, whose plotly figure is displayed.
interface = gr.Interface(
    title="Writing analysis",
    description="Enter text for writing analysis",
    fn=radar_plot,
    inputs=gr.Textbox(label="Input text"),
    outputs=gr.Plot(label="Radar Plot"),
)
# Start the Gradio web server.
interface.launch()
|