# NOTE: captured from a Hugging Face Space page whose status read "Runtime error".
import importlib
import math
import re
import subprocess
import sys

import gradio as gr
import nltk
import plotly.graph_objects as go
import spacy
import textstat
import torch
from nltk import FreqDist
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# NLTK resources used by preprocess_text1/preprocess_text2.
nltk.download('stopwords')
nltk.download('punkt')
# Newer NLTK releases load the Punkt sentence tokenizer from "punkt_tab"
# instead of "punkt"; fetching both keeps sent_tokenize/word_tokenize working
# across NLTK versions.
nltk.download('punkt_tab')

# spaCy pipeline used by calculate_syntactic_tree_depth.
# The model must be installed *before* spacy.load() is called; loading it first
# and downloading afterwards fails with OSError on a fresh machine. Only
# download when the load actually fails, so restarts do not re-download.
SPACY_MODEL = "en_core_web_sm"
try:
    nlp = spacy.load(SPACY_MODEL)
except OSError:
    # Run the download with the current interpreter so the package is
    # installed into the same environment this process imports from.
    subprocess.run(
        [sys.executable, '-m', 'spacy', 'download', SPACY_MODEL, '-q'],
        check=True,
    )
    # The package was installed after start-up; drop stale finder caches so
    # the import machinery can see it.
    importlib.invalidate_caches()
    nlp = spacy.load(SPACY_MODEL)

# for perplexity: GPT-2 model/tokenizer used by calculate_perplexity
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
def normalize(value, min_value, max_value):
    """Rescale value linearly from [min_value, max_value] onto 0-100.

    Values that fall outside the source range are clamped to 0 or 100.
    """
    span = max_value - min_value
    scaled = (value - min_value) * 100 / span
    upper_clamped = min(100, scaled)
    return max(0, upper_clamped)
# vocabulary richness
def preprocess_text1(text):
    """Tokenize text for the vocabulary-richness metric.

    The text is lower-cased and stripped of every character that is neither a
    word character nor whitespace, then split on whitespace. English stopwords
    and purely numeric tokens are dropped.

    Returns:
        list[str]: the surviving tokens, in their original order.
    """
    english_stopwords = set(stopwords.words('english'))
    cleaned = re.sub(r'[^\w\s]', '', text.lower())
    return [
        token
        for token in cleaned.split()
        if token not in english_stopwords and not token.isdigit()
    ]
def vocabulary_richness_ttr(words):
    """Return the type-token ratio of words as a percentage (0-100).

    The ratio is the number of distinct words divided by the total number of
    words, multiplied by 100.

    Args:
        words: a sized sequence of tokens (e.g. from preprocess_text1).

    Returns:
        float: the ratio in percent, or 0.0 for an empty sequence. Empty input
        happens when the text is blank or consists only of stopwords,
        punctuation and numbers; it would otherwise divide by zero.
    """
    if not words:
        return 0.0
    return len(set(words)) / len(words) * 100
def calculate_gunning_fog(text):
    """Return textstat's Gunning fog index for text.

    Nominal range is 0-20 (from the original note).
    """
    return textstat.gunning_fog(text)
def calculate_automated_readability_index(text):
    """Return textstat's Automated Readability Index (ARI) for text.

    Nominal range is 1-20 (from the original note).
    """
    return textstat.automated_readability_index(text)
def calculate_flesch_reading_ease(text):
    """Return textstat's Flesch reading-ease score for text.

    Nominal range is 0-100 (from the original note).
    """
    return textstat.flesch_reading_ease(text)
def preprocess_text2(text):
    """Split text into sentences and content words.

    Returns:
        tuple[list[str], list[str]]: (words, sentences). sentences is the output
        of sent_tokenize. words holds, for every sentence in order, the
        word_tokenize tokens that are alphanumeric (punctuation is dropped),
        lower-cased, with English stopwords removed.
    """
    sentences = sent_tokenize(text)
    english_stopwords = set(stopwords.words('english'))
    words = []
    for sentence in sentences:
        for token in word_tokenize(sentence):
            if not token.isalnum():
                continue
            lowered = token.lower()
            if lowered not in english_stopwords:
                words.append(lowered)
    return words, sentences
def calculate_average_sentence_length(sentences):
    """Return the mean number of word_tokenize tokens per sentence.

    Nominal range is 0-40 (or 50), based on the histogram.

    Unlike the word list built by preprocess_text2, no isalnum() filtering is
    applied here, so every token word_tokenize yields is counted.

    Args:
        sentences: a sized sequence of sentence strings.

    Returns:
        float: total tokens / number of sentences, or 0.0 when there are no
        sentences. An explicit guard replaces the old "+ 0.0000001" divisor,
        which skewed every non-empty result slightly low.
    """
    if not sentences:
        return 0.0
    total_words = sum(len(word_tokenize(sentence)) for sentence in sentences)
    return total_words / len(sentences)
def calculate_average_word_length(words):
    """Return the mean character length of words.

    Nominal range is 0-8, based on the histogram.

    Args:
        words: a sized sequence of word strings.

    Returns:
        float: total characters / number of words, or 0.0 for an empty
        sequence. An explicit guard replaces the old "+ 0.0000001" divisor,
        which skewed every non-empty result slightly low.
    """
    if not words:
        return 0.0
    total_characters = sum(len(word) for word in words)
    return total_characters / len(words)
def calculate_max_depth(sent):
    """Return the largest ancestor count among the tokens of sent.

    Each token's ancestors (spaCy Token.ancestors) are counted. The maximum
    over the sentence is taken as the sentence's depth. Raises ValueError if
    sent has no tokens.
    """
    ancestor_counts = [sum(1 for _ in token.ancestors) for token in sent]
    return max(ancestor_counts)
def calculate_syntactic_tree_depth(text):
    """Average, over the spaCy sentences of text, of each sentence's max depth.

    Nominal range is 0-10, based on the histogram. Returns 0 when the parse
    yields no sentences.
    """
    parsed = nlp(text)
    per_sentence = [calculate_max_depth(sentence) for sentence in parsed.sents]
    if not per_sentence:
        return 0
    return sum(per_sentence) / len(per_sentence)
# reference: https://huggingface.co/docs/transformers/perplexity
def calculate_perplexity(text, stride=512):
    """Return the perplexity of text under the module-level GPT-2 model.

    Nominal range is 0-30, based on the histogram.

    The text is scored with a sliding window. Each window holds up to
    model.config.n_positions tokens, and window starts advance by stride
    tokens. Inside a window, tokens already scored by the previous window are
    masked with -100, so they serve only as context.

    Args:
        text: input string.
        stride: step between window starts, in tokens.

    Returns:
        float: exp of the mean negative log-likelihood per scored token.

    Raises:
        ValueError: if text encodes to fewer than two tokens. With no token
            that has a preceding context, perplexity is undefined (the old code
            crashed on empty input and produced NaN for a single token).
    """
    all_input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = all_input_ids.size(1)
    if seq_len < 2:
        raise ValueError(
            "Text is too short to compute perplexity (need at least two tokens)."
        )
    max_length = model.config.n_positions
    nll_sum = 0.0
    n_scored = 0
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = all_input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Mask context tokens that the previous window already scored.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
        # Per the HF perplexity guide, the model shifts labels by one position
        # internally and outputs.loss is the mean NLL over the labels that
        # remain unmasked after that shift. Weighting by that count gives a true
        # per-token average, rather than letting a short final window count as
        # much as a full one.
        n_window = int((target_ids[:, 1:] != -100).sum().item())
        if n_window:  # a window with nothing to score yields a NaN loss
            nll_sum += outputs.loss.item() * n_window
            n_scored += n_window
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    # n_scored > 0: the first window scores at least seq_len - 1 >= 1 tokens.
    return math.exp(nll_sum / n_scored)
def _compute_features(input_text):
    """Compute the six radar-plot features for input_text, each on a 0-100 scale.

    Raises:
        gr.Error: if the perplexity calculation rejects the text with a
            ValueError, so the reason is shown in the Gradio UI.
    """
    # vocabulary richness: vocabulary_richness_ttr already returns a percentage
    ttr_value = vocabulary_richness_ttr(preprocess_text1(input_text))

    # readability: Gunning fog index, normalized from 0-20
    gunning_fog = calculate_gunning_fog(input_text)
    gunning_fog_norm = normalize(gunning_fog, min_value=0, max_value=20)

    # average sentence length (0-40) and average word length (0-8)
    words, sentences = preprocess_text2(input_text)
    average_sentence_length = calculate_average_sentence_length(sentences)
    average_word_length = calculate_average_word_length(words)
    average_sentence_length_norm = normalize(average_sentence_length, min_value=0, max_value=40)
    average_word_length_norm = normalize(average_word_length, min_value=0, max_value=8)

    # syntactic tree depth, normalized from 0-10
    average_tree_depth = calculate_syntactic_tree_depth(input_text)
    average_tree_depth_norm = normalize(average_tree_depth, min_value=0, max_value=10)

    # perplexity, normalized from 0-30
    try:
        perplexity = calculate_perplexity(input_text)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    perplexity_norm = normalize(perplexity, min_value=0, max_value=30)

    return {
        "readability": gunning_fog_norm,
        "syntactic tree depth": average_tree_depth_norm,
        "vocabulary richness": ttr_value,
        "perplexity": perplexity_norm,
        "average sentence length": average_sentence_length_norm,
        "average word length": average_word_length_norm,
    }


def _build_radar_figure(features):
    """Build a filled plotly polar chart with one axis per feature and radial range 0-100."""
    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=list(features.values()),
        theta=list(features.keys()),
        fill='toself',
        name='Radar Plot',
    ))
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
            )),
        showlegend=False,
        margin=dict(l=10, r=20, b=10, t=10),
    )
    return fig


def radar_plot(input_text):
    """Analyse input_text and return a plotly radar chart of its writing features.

    The features are readability, syntactic tree depth, vocabulary richness,
    perplexity, average sentence length and average word length.

    Raises:
        gr.Error: for empty or whitespace-only input, which the metric helpers
            cannot score, and when perplexity cannot be computed.
    """
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter some text to analyse.")
    features = _compute_features(input_text)
    print(features)  # logged to the server console for inspection
    return _build_radar_figure(features)
# Gradio Interface: a single text box in, the radar plot out.
text_input = gr.Textbox(label="Input text")
plot_output = gr.Plot(label="Radar Plot")
interface = gr.Interface(
    fn=radar_plot,
    inputs=text_input,
    outputs=plot_output,
    title="Writing analysis",
    description="Enter text for writing analysis",
)
interface.launch()