sashdev committed on
Commit c93affb · verified · 1 Parent(s): a0660ab

Update app.py

Files changed (1)
  1. app.py +225 -35
app.py CHANGED
@@ -1,36 +1,226 @@
 
+import os
 import gradio as gr
-from gector.gec_model import GecBERTModel
-
-# Load the GECToR model
-def load_model():
-    model = GecBERTModel(
-        vocab_path='data/output_vocabulary',
-        model_paths=['data/model_files/xlnet_0_gector.th'],
-        max_len=128, min_len=3
-    )
-    return model
-
-# Function to correct grammar using GECToR model
-def correct_grammar(text):
-    # Initialize the model (you can load it once and use globally to avoid reloading each time)
-    model = load_model()
-    # Correct the input text
-    corrected_text = model.handle_batch([text])
-    return corrected_text[0]  # Since the result is a list, return the first (and only) item
-
-# Define Gradio interface
-def create_gradio_interface():
-    # Input and output in Gradio
-    interface = gr.Interface(
-        fn=correct_grammar,  # Function to run
-        inputs="text",  # Input is plain text
-        outputs="text",  # Output is plain text
-        title="Grammar Correction App using GECToR",
-        description="Enter your text, and this app will correct its grammar using GECToR."
-    )
-    return interface
-
-# Launch the Gradio app
-if __name__ == "__main__":
-    gradio_interface = create_gradio_interface()
-    gradio_interface.launch()
+from transformers import pipeline
+import spacy
+import subprocess
+import nltk
+from nltk.corpus import wordnet
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+from spellchecker import SpellChecker
+import re
+import string
+import random
+
+# Download necessary NLTK data
+nltk.download('punkt')
+nltk.download('stopwords')
+nltk.download('averaged_perceptron_tagger')
+nltk.download('wordnet')
+
+# Initialize stopwords
+stop_words = set(stopwords.words("english"))
+
+# Words we don't want to replace
+exclude_tags = {'PRP', 'PRP$', 'MD', 'VBZ', 'VBP', 'VBD', 'VBG', 'VBN', 'TO', 'IN', 'DT', 'CC'}
+exclude_words = {'is', 'am', 'are', 'was', 'were', 'have', 'has', 'do', 'does', 'did', 'will', 'shall', 'should', 'would', 'could', 'can', 'may', 'might'}
+
+# Initialize the English text classification pipeline for AI detection
+pipeline_en = pipeline(task="text-classification", model="Hello-SimpleAI/chatgpt-detector-roberta")
+
+# Initialize the spell checker
+spell = SpellChecker()
+
+# Ensure the SpaCy model is installed
+try:
+    nlp = spacy.load("en_core_web_sm")
+except OSError:
+    subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
+    nlp = spacy.load("en_core_web_sm")
+
+def plagiarism_removal(text):
+    def plagiarism_remover(word):
+        if word.lower() in stop_words or word.lower() in exclude_words or word in string.punctuation:
+            return word
+
+        # Find synonyms
+        synonyms = set()
+        for syn in wordnet.synsets(word):
+            for lemma in syn.lemmas():
+                if "_" not in lemma.name() and lemma.name().isalpha() and lemma.name().lower() != word.lower():
+                    synonyms.add(lemma.name())
+
+        pos_tag_word = nltk.pos_tag([word])[0]
+
+        if pos_tag_word[1] in exclude_tags:
+            return word
+
+        filtered_synonyms = [syn for syn in synonyms if nltk.pos_tag([syn])[0][1] == pos_tag_word[1]]
+
+        if not filtered_synonyms:
+            return word
+
+        synonym_choice = random.choice(filtered_synonyms)
+
+        if word.istitle():
+            return synonym_choice.title()
+        return synonym_choice
+
+    para_split = word_tokenize(text)
+    final_text = [plagiarism_remover(word) for word in para_split]
+
+    corrected_text = []
+    for i in range(len(final_text)):
+        if final_text[i] in string.punctuation and i > 0:
+            corrected_text[-1] += final_text[i]
+        else:
+            corrected_text.append(final_text[i])
+
+    return " ".join(corrected_text)
+
+def predict_en(text):
+    res = pipeline_en(text)[0]
+    return res['label'], res['score']
+
+def remove_redundant_words(text):
+    doc = nlp(text)
+    meaningless_words = {"actually", "basically", "literally", "really", "very", "just"}
+    filtered_text = [token.text for token in doc if token.text.lower() not in meaningless_words]
+    return ' '.join(filtered_text)
+
+def fix_punctuation_spacing(text):
+    words = text.split(' ')
+    cleaned_words = []
+    punctuation_marks = {',', '.', "'", '!', '?', ':'}
+
+    for word in words:
+        if cleaned_words and word and word[0] in punctuation_marks:
+            cleaned_words[-1] += word
+        else:
+            cleaned_words.append(word)
+
+    return ' '.join(cleaned_words).replace(' ,', ',').replace(' .', '.').replace(" '", "'") \
+        .replace(' !', '!').replace(' ?', '?').replace(' :', ':')
+
+def fix_possessives(text):
+    text = re.sub(r'(\w)\s\'\s?s', r"\1's", text)
+    return text
+
+def capitalize_sentences_and_nouns(text):
+    doc = nlp(text)
+    corrected_text = []
+
+    for sent in doc.sents:
+        sentence = []
+        for token in sent:
+            if token.i == sent.start:
+                sentence.append(token.text.capitalize())
+            elif token.pos_ == "PROPN":
+                sentence.append(token.text.capitalize())
+            else:
+                sentence.append(token.text)
+        corrected_text.append(' '.join(sentence))
+
+    return ' '.join(corrected_text)
+
+def force_first_letter_capital(text):
+    sentences = re.split(r'(?<=\w[.!?])\s+', text)
+    capitalized_sentences = []
+
+    for sentence in sentences:
+        if sentence:
+            capitalized_sentence = sentence[0].capitalize() + sentence[1:]
+            if not re.search(r'[.!?]$', capitalized_sentence):
+                capitalized_sentence += '.'
+            capitalized_sentences.append(capitalized_sentence)
+
+    return " ".join(capitalized_sentences)
+
+def correct_tense_errors(text):
+    doc = nlp(text)
+    corrected_text = []
+    for token in doc:
+        if token.pos_ == "VERB" and token.dep_ in {"aux", "auxpass"}:
+            lemma = wordnet.morphy(token.text, wordnet.VERB) or token.text
+            corrected_text.append(lemma)
+        else:
+            corrected_text.append(token.text)
+    return ' '.join(corrected_text)
+
+def correct_article_errors(text):
+    doc = nlp(text)
+    corrected_text = []
+    for token in doc:
+        if token.text in ['a', 'an']:
+            next_token = token.nbor(1)
+            if token.text == "a" and next_token.text[0].lower() in "aeiou":
+                corrected_text.append("an")
+            elif token.text == "an" and next_token.text[0].lower() not in "aeiou":
+                corrected_text.append("a")
+            else:
+                corrected_text.append(token.text)
+        else:
+            corrected_text.append(token.text)
+    return ' '.join(corrected_text)
+
+def ensure_subject_verb_agreement(text):
+    doc = nlp(text)
+    corrected_text = []
+    for token in doc:
+        if token.dep_ == "nsubj" and token.head.pos_ == "VERB":
+            if token.tag_ == "NN" and token.head.tag_ != "VBZ":
+                corrected_text.append(token.head.lemma_ + "s")
+            elif token.tag_ == "NNS" and token.head.tag_ == "VBZ":
+                corrected_text.append(token.head.lemma_)
+        corrected_text.append(token.text)
+    return ' '.join(corrected_text)
+
+def correct_spelling(text):
+    words = text.split()
+    corrected_words = []
+    for word in words:
+        corrected_word = spell.correction(word)
+        if corrected_word is not None:
+            corrected_words.append(corrected_word)
+        else:
+            corrected_words.append(word)
+    return ' '.join(corrected_words)
+
+def paraphrase_and_correct(text):
+    paragraphs = text.split("\n\n")  # Split by paragraphs
+
+    # Process each paragraph separately
+    processed_paragraphs = []
+    for paragraph in paragraphs:
+        cleaned_text = remove_redundant_words(paragraph)
+        plag_removed = plagiarism_removal(cleaned_text)
+        paraphrased_text = capitalize_sentences_and_nouns(plag_removed)
+        paraphrased_text = force_first_letter_capital(paraphrased_text)
+        paraphrased_text = correct_article_errors(paraphrased_text)
+        paraphrased_text = correct_tense_errors(paraphrased_text)
+        paraphrased_text = ensure_subject_verb_agreement(paraphrased_text)
+        paraphrased_text = fix_possessives(paraphrased_text)  # Fixed typo here
+        paraphrased_text = correct_spelling(paraphrased_text)
+        paraphrased_text = fix_punctuation_spacing(paraphrased_text)
+        processed_paragraphs.append(paraphrased_text)
+
+    return "\n\n".join(processed_paragraphs)  # Reassemble the text with paragraphs
+
+# Gradio app setup
+with gr.Blocks() as demo:
+    with gr.Tab("AI Detection"):
+        t1 = gr.Textbox(lines=5, label='Text')
+        button1 = gr.Button("🤖 Predict!")
+        label1 = gr.Textbox(lines=1, label='Predicted Label 🎃')
+        score1 = gr.Textbox(lines=1, label='Prob')
+
+        button1.click(fn=predict_en, inputs=t1, outputs=[label1, score1])
+
+    with gr.Tab("Paraphrasing & Grammar Correction"):
+        t2 = gr.Textbox(lines=5, label='Enter text for paraphrasing and grammar correction')
+        button2 = gr.Button("🔄 Paraphrase and Correct")
+        result2 = gr.Textbox(lines=5, label='Corrected Text')
+
+        button2.click(fn=paraphrase_and_correct, inputs=t2, outputs=result2)
+
+demo.launch(share=True)
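
Note that the new app.py calls demo.launch(share=True) at module scope, so importing it starts the Gradio server. To smoke-test the new AI-detection path outside the UI, the same pipeline call that predict_en() wraps can be exercised directly. The snippet below is a minimal sketch, not part of the commit; it assumes transformers is installed, the Hello-SimpleAI/chatgpt-detector-roberta model can be downloaded from the Hub, and the sample sentence is an arbitrary placeholder.

# Standalone check of the detector call used by predict_en()
from transformers import pipeline

detector = pipeline(task="text-classification",
                    model="Hello-SimpleAI/chatgpt-detector-roberta")

sample = "This is a short sample sentence for the detector."
result = detector(sample)[0]          # list of dicts; take the top prediction
print(result["label"], result["score"])  # predicted label and its probability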