sashdev commited on
Commit
a4b85c4
·
verified ·
1 Parent(s): 625eebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -37
app.py CHANGED
@@ -1,46 +1,110 @@
1
- # Imports
 
 
 
 
 
 
2
  import gradio as gr
3
- import torch
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
- # Load the tokenizer and model
7
- tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
8
- model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")
 
9
 
10
- # Use GPU if available
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model.to(device)
13
 
14
- # Grammar correction function
15
- def correct_grammar(text):
16
- # Tokenize input text with an increased max_length for handling larger input
17
- inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
18
-
19
- # Generate corrected text with increased max_length and num_beams
20
- outputs = model.generate(**inputs, max_length=2024, num_beams=5, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Decode the output and return the corrected text
23
- corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- return corrected_text
25
-
26
- # Gradio interface function
27
- def correct_grammar_interface(text):
28
- corrected_text = correct_grammar(text)
29
- return corrected_text
30
-
31
- # Gradio app interface
32
- with gr.Blocks() as grammar_app:
33
- gr.Markdown("<h1>Grammar Correction App (up to 300 words)</h1>")
34
 
35
- with gr.Row():
36
- input_box = gr.Textbox(label="Input Text", placeholder="Enter text (up to 300 words)", lines=10)
37
- output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- submit_button = gr.Button("Correct Grammar")
 
 
 
 
40
 
41
- # Bind the button click to the grammar correction function
42
- submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
43
 
44
- # Launch the app
45
- if __name__ == "__main__":
46
- grammar_app.launch()
 
1
+ import os
2
+ import random
3
+ import re
4
+ import string
5
+ import spacy
6
+ from nltk.corpus import wordnet
7
+ import nltk
8
  import gradio as gr
 
 
9
 
10
+ # Ensure that necessary NLTK resources are downloaded
11
+ nltk.download('punkt')
12
+ nltk.download('averaged_perceptron_tagger')
13
+ nltk.download('wordnet')
14
 
15
+ # Load SpaCy model
16
+ nlp = spacy.load("en_core_web_sm")
 
17
 
18
+ # Exclude tags and words (adjusted for better precision)
19
+ exclude_tags = {'PRP', 'PRP$', 'MD', 'VBZ', 'VBP', 'VBD', 'VBG', 'VBN', 'TO', 'IN', 'DT', 'CC'}
20
+ exclude_words = {'is', 'am', 'are', 'was', 'were', 'have', 'has', 'do', 'does', 'did', 'will', 'shall', 'should', 'would', 'could', 'can', 'may', 'might'}
21
+
22
+ def get_synonyms(word):
23
+ """Find synonyms for a given word considering the context."""
24
+ synonyms = set()
25
+ for syn in wordnet.synsets(word):
26
+ for lemma in syn.lemmas():
27
+ if "_" not in lemma.name() and lemma.name().isalpha() and lemma.name().lower() != word.lower():
28
+ synonyms.add(lemma.name())
29
+ return synonyms
30
+
31
+ def replace_with_synonyms(word, pos_tag):
32
+ """Replace words with synonyms, keeping the original POS tag."""
33
+ synonyms = get_synonyms(word)
34
+ # Filter by POS tag
35
+ filtered_synonyms = [syn for syn in synonyms if nltk.pos_tag([syn])[0][1] == pos_tag]
36
+ if filtered_synonyms:
37
+ return random.choice(filtered_synonyms)
38
+ return word
39
+
40
+ def improve_paraphrasing_and_grammar(text):
41
+ """Paraphrase and correct grammatical errors in the text."""
42
+ doc = nlp(text)
43
+ corrected_text = []
44
+
45
+ for sent in doc.sents:
46
+ sentence = []
47
+ for token in sent:
48
+ # Replace words with synonyms, excluding special POS tags
49
+ if token.tag_ not in exclude_tags and token.text.lower() not in exclude_words and token.text not in string.punctuation:
50
+ synonym = replace_with_synonyms(token.text, token.tag_)
51
+ sentence.append(synonym if synonym else token.text)
52
+ else:
53
+ sentence.append(token.text)
54
+
55
+ corrected_text.append(' '.join(sentence))
56
 
57
+ # Ensure proper punctuation and capitalization
58
+ final_text = ' '.join(corrected_text)
59
+ final_text = fix_possessives(final_text)
60
+ final_text = fix_punctuation_spacing(final_text)
61
+ final_text = capitalize_sentences(final_text)
62
+ final_text = fix_article_errors(final_text)
 
 
 
 
 
 
63
 
64
+ return final_text
65
+
66
+ def fix_punctuation_spacing(text):
67
+ """Fix spaces before punctuation marks."""
68
+ text = re.sub(r'\s+([,.!?])', r'\1', text)
69
+ return text
70
+
71
+ def fix_possessives(text):
72
+ """Correct possessives like 'John ' s' -> 'John's'."""
73
+ return re.sub(r"(\w)\s?'\s?s", r"\1's", text)
74
+
75
+ def capitalize_sentences(text):
76
+ """Capitalize the first letter of each sentence."""
77
+ return '. '.join([s.capitalize() for s in re.split(r'(?<=\w[.!?])\s+', text)])
78
+
79
+ def fix_article_errors(text):
80
+ """Correct 'a' and 'an' usage based on following word's sound."""
81
+ doc = nlp(text)
82
+ corrected = []
83
+ for token in doc:
84
+ if token.text in ('a', 'an'):
85
+ next_token = token.nbor(1)
86
+ if token.text == "a" and next_token.text[0].lower() in "aeiou":
87
+ corrected.append("an")
88
+ elif token.text == "an" and next_token.text[0].lower() not in "aeiou":
89
+ corrected.append("a")
90
+ else:
91
+ corrected.append(token.text)
92
+ else:
93
+ corrected.append(token.text)
94
+ return ' '.join(corrected)
95
+
96
+ # Gradio app setup
97
+ def gradio_interface(text):
98
+ """Gradio interface function to process the input text."""
99
+ return improve_paraphrasing_and_grammar(text)
100
 
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("## Text Paraphrasing and Grammar Correction")
103
+ text_input = gr.Textbox(lines=10, label='Enter text for paraphrasing and grammar correction')
104
+ text_output = gr.Textbox(lines=10, label='Corrected Text', interactive=False)
105
+ submit_button = gr.Button("🔄 Paraphrase and Correct")
106
 
107
+ submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)
 
108
 
109
+ # Launch the Gradio app
110
+ demo.launch(share=True)