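"""Writing-analysis demo: computes six stylometric features of an input text
(vocabulary richness, readability, average sentence/word length, syntactic
tree depth, and GPT-2 perplexity), normalizes each to 0-100, and renders
them on a Plotly radar chart behind a Gradio interface."""
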
import re
import subprocess

import nltk
import spacy
import textstat
import torch
import gradio as gr
import plotly.graph_objects as go
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from tqdm import tqdm

nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)

# Download the spaCy model before loading it; loading first fails on a fresh environment.
subprocess.run(['python', '-m', 'spacy', 'download', 'en_core_web_sm', '-q'])
nlp = spacy.load("en_core_web_sm")

# for perplexity
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

def normalize(value, min_value, max_value):
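    """Min-max scale value from [min_value, max_value] to 0-100, clamped to that range."""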
    normalized_value = ((value - min_value) * 100) / (max_value - min_value)
    return max(0, min(100, normalized_value))

# vocabulary richness
def preprocess_text1(text):
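    """Lowercase, strip punctuation, stopwords, and numbers; return the remaining words."""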
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text) # remove punctuation
    stop_words = set(stopwords.words('english')) # remove stopwords
    words = [word for word in text.split() if word not in stop_words]
    words = [word for word in words if not word.isdigit()] # remove numbers
    return words

def vocabulary_richness_ttr(words):
    """Type-token ratio: unique words as a percentage of all words."""
    if not words:
        return 0
    unique_words = set(words)
    ttr = len(unique_words) / len(words) * 100
    return ttr

def calculate_gunning_fog(text):
    """range 0-20"""
    gunning_fog = textstat.gunning_fog(text)
    return gunning_fog

def calculate_automated_readability_index(text):
    """range 1-20"""
    ari = textstat.automated_readability_index(text)
    return ari

def calculate_flesch_reading_ease(text):
    """range 0-100"""
    fre = textstat.flesch_reading_ease(text)
    return fre

def preprocess_text2(text):
    # tokenize into words and remove punctuation
    sentences = sent_tokenize(text)
    words = [word.lower() for sent in sentences for word in word_tokenize(sent) if word.isalnum()]
    # remove stopwords
    stop_words = set(stopwords.words('english'))
    words = [word for word in words if word not in stop_words]
    return words, sentences

def calculate_average_sentence_length(sentences):
    """range 0-40 or 50 based on the histogram"""
    if not sentences:
        return 0
    total_words = sum(len(word_tokenize(sent)) for sent in sentences)
    return total_words / len(sentences)

def calculate_average_word_length(words):
    """range 0-8 based on the histogram"""
    if not words:
        return 0
    total_characters = sum(len(word) for word in words)
    return total_characters / len(words)

def calculate_max_depth(sent):
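    """Maximum dependency-tree depth (ancestor count) over the tokens of one sentence."""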
    return max((len(list(token.ancestors)) for token in sent), default=0)

def calculate_syntactic_tree_depth(text):
    """0-10 based on the histogram"""
    doc = nlp(text)
    sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
    average_depth = sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
    return average_depth

# reference: https://huggingface.co/docs/transformers/perplexity
def calculate_perplexity(text, stride=512):
    """range 0-30 based on the histogram"""
    encodings = tokenizer(text, return_tensors="pt")
    max_length = model.config.n_positions
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
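        # Mask all but the last trg_len tokens with -100 so the loss (and thus
        # the perplexity) only covers tokens not already scored by a previous window.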
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl.item()


def radar_plot(input_text):
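    """Compute the six normalized features for input_text and return a Plotly radar chart."""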

    # vocabulary richness
    processed_words = preprocess_text1(input_text)
    ttr_value = vocabulary_richness_ttr(processed_words)

    # readability
    gunning_fog = calculate_gunning_fog(input_text)
    gunning_fog_norm = normalize(gunning_fog, min_value=0, max_value=20)
    
    # average sentence length and average word length
    words, sentences = preprocess_text2(input_text)
    average_sentence_length = calculate_average_sentence_length(sentences)
    average_word_length = calculate_average_word_length(words)
    average_sentence_length_norm = normalize(average_sentence_length, min_value=0, max_value=40)
    average_word_length_norm = normalize(average_word_length, min_value=0, max_value=8)

    # syntactic_tree_depth
    average_tree_depth = calculate_syntactic_tree_depth(input_text)
    average_tree_depth_norm = normalize(average_tree_depth, min_value=0, max_value=10)

    # perplexity
    perplexity = calculate_perplexity(input_text)
    perplexity_norm = normalize(perplexity, min_value=0, max_value=30)

    features = {
        "readability": gunning_fog_norm, 
        "syntactic tree depth": average_tree_depth_norm,
        "vocabulary richness": ttr_value,
        "perplexity": perplexity_norm,
        "average sentence length": average_sentence_length_norm,
        "average word length": average_word_length_norm, 
    }

    print(features)

    fig = go.Figure()

    fig.add_trace(go.Scatterpolar(
        r=list(features.values()),
        theta=list(features.keys()),
        fill='toself',
        name='Radar Plot'
    ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
            )),
        showlegend=False,
        # autosize=False,
        # width=600,
        # height=600,
        margin=dict(
            l=10,
            r=20,
            b=10,
            t=10,
            # pad=100
        ),
    )

    return fig

# Gradio Interface
interface = gr.Interface(
    fn=radar_plot,
    inputs=gr.Textbox(label="Input text"),
    outputs=gr.Plot(label="Radar Plot"),
    title="Writing analysis",
    description="Enter text for writing analysis",
)

interface.launch()
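
# Tip: when running locally, interface.launch(share=True) also creates a
# temporary public link (share is a standard Gradio launch option).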