copyright_checker / writing_analysis.py
eljanmahammadli's picture
added writing analysis code
f75d1f0
raw
history blame
6.3 kB
import re, nltk, spacy, textstat, subprocess
from nltk import FreqDist
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from tqdm import tqdm
import gradio as gr
import plotly.graph_objects as go
# Fetch the NLTK resources used by the stop-word filters and tokenizers below.
nltk.download('stopwords')
nltk.download('punkt')

# Load the spaCy English pipeline. The model must be installed before it can
# be loaded, so when the first load fails because the package is missing
# (spaCy raises OSError in that case) download it and retry, instead of
# loading first and downloading afterwards.
command = ['python', '-m', 'spacy', 'download', 'en_core_web_sm', '-q']
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Execute the download command, then retry the load.
    subprocess.run(command)
    nlp = spacy.load("en_core_web_sm")
# Language model used for perplexity scoring: run on the GPU when one is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "gpt2-large"
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id)
model = model.to(device)
def normalize(value, min_value, max_value):
    """Rescale *value* from [min_value, max_value] onto a 0-100 scale.

    Values that fall outside the source range are clamped to 0 or 100.
    """
    span = max_value - min_value
    scaled = ((value - min_value) * 100) / span
    # Clamp from above first, then from below.
    capped = min(100, scaled)
    return max(0, capped)
# vocabulary richness
def preprocess_text1(text):
    """Return the lower-cased content words of *text*.

    Characters that are neither word characters nor whitespace are removed,
    the result is split on whitespace, and English stop words and
    digit-only tokens are discarded.
    """
    lowered = text.lower()
    # Strip punctuation before splitting.
    cleaned = re.sub(r'[^\w\s]', '', lowered)
    english_stop_words = set(stopwords.words('english'))

    kept = []
    for token in cleaned.split():
        if token in english_stop_words:
            continue  # stop word
        if token.isdigit():
            continue  # plain number
        kept.append(token)
    return kept
def vocabulary_richness_ttr(words):
    """Return the type-token ratio of *words* as a percentage.

    The ratio is the number of distinct words divided by the total number
    of words, multiplied by 100.

    Args:
        words: Sequence of word tokens.

    Returns:
        The type-token ratio in percent, or 0.0 when *words* is empty
        (there is nothing to measure, and dividing by the zero length
        would otherwise raise ZeroDivisionError).
    """
    if not words:
        return 0.0
    return len(set(words)) / len(words) * 100
def calculate_gunning_fog(text):
    """Return textstat's Gunning fog score for *text* (expected range 0-20)."""
    return textstat.gunning_fog(text)
def calculate_automated_readability_index(text):
    """Return textstat's automated readability index for *text* (expected range 1-20)."""
    return textstat.automated_readability_index(text)
def calculate_flesch_reading_ease(text):
    """Return textstat's Flesch reading-ease score for *text* (expected range 0-100)."""
    return textstat.flesch_reading_ease(text)
def preprocess_text2(text):
    """Split *text* into sentences and filtered, lower-cased word tokens.

    Returns:
        A ``(words, sentences)`` tuple. ``sentences`` is the output of
        ``sent_tokenize``; ``words`` holds every alphanumeric token of those
        sentences, lower-cased, with English stop words removed.
    """
    sentences = sent_tokenize(text)
    english_stop_words = set(stopwords.words('english'))

    words = []
    for sentence in sentences:
        for token in word_tokenize(sentence):
            # Non-alphanumeric tokens (punctuation and the like) are skipped.
            if not token.isalnum():
                continue
            lowered = token.lower()
            if lowered not in english_stop_words:
                words.append(lowered)
    return words, sentences
def calculate_average_sentence_length(sentences):
    """Return the mean number of ``word_tokenize`` tokens per sentence.

    Expected range is roughly 0-40 or 50 (based on the histogram).
    """
    token_count = 0
    for sentence in sentences:
        token_count += len(word_tokenize(sentence))
    # The tiny epsilon keeps an empty sentence list from dividing by zero.
    return token_count / (len(sentences) + 0.0000001)
def calculate_average_word_length(words):
    """Return the mean number of characters per word.

    Expected range is roughly 0-8 (based on the histogram).
    """
    char_count = sum(map(len, words))
    # The tiny epsilon keeps an empty word list from dividing by zero.
    return char_count / (len(words) + 0.0000001)
def calculate_max_depth(sent):
    """Return the largest ancestor count among the tokens of *sent*."""
    depths = [sum(1 for _ in token.ancestors) for token in sent]
    return max(depths)
def calculate_syntactic_tree_depth(text):
    """Return the mean per-sentence maximum tree depth of *text*.

    The text is parsed with the module-level spaCy pipeline and each
    sentence is scored with ``calculate_max_depth``. Returns 0 when the
    parse yields no sentences. Expected range is roughly 0-10 (based on
    the histogram).
    """
    doc = nlp(text)
    depths = []
    for sentence in doc.sents:
        depths.append(calculate_max_depth(sentence))
    if not depths:
        return 0
    return sum(depths) / len(depths)
# reference: https://huggingface.co/docs/transformers/perplexity
def calculate_perplexity(text, stride=512):
    """Return the perplexity of *text* under the module-level GPT-2 model.

    The token sequence is scored with a sliding window: each window holds at
    most ``model.config.n_positions`` tokens and starts ``stride`` tokens
    after the previous one. Tokens already scored by an earlier window are
    masked out so they serve only as context. Expected range is roughly
    0-30 (based on the histogram).

    Args:
        text: The text to score.
        stride: Number of tokens between the starts of consecutive windows.

    Returns:
        The perplexity as a float: the exponential of the mean negative
        log-likelihood per scored token. Returns 0.0 when there is nothing
        to score (for example empty text), instead of failing on an empty
        list of window losses.
    """
    encodings = tokenizer(text, return_tensors="pt")
    max_length = model.config.n_positions
    seq_len = encodings.input_ids.size(1)

    nll_sum = 0.0  # negative log-likelihood summed over every scored token
    n_tokens = 0   # number of tokens that contributed to nll_sum
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Labels of -100 are ignored by the loss: everything except the last
        # trg_len tokens was scored by a previous window.
        target_ids[:, :-trg_len] = -100

        # The model shifts the labels by one position internally, so only
        # unmasked labels from index 1 onward take part in the loss.
        num_loss_tokens = (target_ids[:, 1:] != -100).sum().item()
        if num_loss_tokens > 0:
            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
            # outputs.loss is the mean over this window's scored tokens.
            # Weight it by the token count so a short final window does not
            # count as much as a full one in the overall average.
            nll_sum = nll_sum + outputs.loss * num_loss_tokens
            n_tokens += num_loss_tokens

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if n_tokens == 0:
        return 0.0
    ppl = torch.exp(nll_sum / n_tokens)
    return ppl.item()
def radar_plot(input_text):
    """Compute the writing metrics of *input_text* and plot them on a radar chart.

    Every metric is placed on a 0-100 scale. The type-token ratio is used
    as is; the others are normalized against fixed ranges. The feature
    dictionary is printed before the figure is built.

    Returns:
        A Plotly figure containing a single filled polar trace.
    """
    # vocabulary richness
    ttr_value = vocabulary_richness_ttr(preprocess_text1(input_text))

    # readability
    gunning_fog_norm = normalize(
        calculate_gunning_fog(input_text), min_value=0, max_value=20
    )

    # average sentence length and average word length
    words, sentences = preprocess_text2(input_text)
    average_sentence_length_norm = normalize(
        calculate_average_sentence_length(sentences), min_value=0, max_value=40
    )
    average_word_length_norm = normalize(
        calculate_average_word_length(words), min_value=0, max_value=8
    )

    # syntactic tree depth
    average_tree_depth_norm = normalize(
        calculate_syntactic_tree_depth(input_text), min_value=0, max_value=10
    )

    # perplexity
    perplexity_norm = normalize(
        calculate_perplexity(input_text), min_value=0, max_value=30
    )

    # Insertion order fixes the order of the axes around the chart.
    features = {
        "readability": gunning_fog_norm,
        "syntactic tree depth": average_tree_depth_norm,
        "vocabulary richness": ttr_value,
        "perplexity": perplexity_norm,
        "average sentence length": average_sentence_length_norm,
        "average word length": average_word_length_norm,
    }
    print(features)

    trace = go.Scatterpolar(
        r=list(features.values()),
        theta=list(features),
        fill='toself',
        name='Radar Plot',
    )
    radial_axis = {"visible": True, "range": [0, 100]}
    margins = {"l": 10, "r": 20, "b": 10, "t": 10}

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        polar={"radialaxis": radial_axis},
        showlegend=False,
        margin=margins,
    )
    return fig
# Gradio Interface
# Single text box in, radar chart out.
_text_input = gr.Textbox(label="Input text")
_plot_output = gr.Plot(label="Radar Plot")
interface = gr.Interface(
    fn=radar_plot,
    inputs=_text_input,
    outputs=_plot_output,
    title="Writing analysis",
    description="Enter text for writing analysis",
)
# Start the web app as soon as the module is executed.
interface.launch()