sashdev committed · verified
Commit 19f79d5 · Parent(s): 043440f

Update app.py

Files changed (1): app.py (+22 −21)
app.py CHANGED
@@ -1,32 +1,33 @@
-import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 
-# Load the grammar correction model and tokenizer
+# Load model and tokenizer
 model_name = "hassaanik/grammar-correction-model"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
-# Function for grammar correction using the grammar correction model
-def correct_grammar(text):
-    # Tokenize the input text
-    inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
+# Use GPU if available, otherwise fall back to CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 
-    # Generate the corrected output from the model
-    outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)
+# Use FP16 for faster inference on GPU
+if torch.cuda.is_available():
+    model.half()
 
-    # Decode the generated tokens to get the corrected text
+# Function to correct grammar for a single text input
+def correct_grammar(text):
+    # Tokenize input and move it to the correct device (CPU/GPU)
+    inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
+
+    # Generate corrected output with beam search
+    outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)
+
+    # Decode output and return corrected text
     corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
     return corrected_text
 
-# Gradio interface for the grammar correction app
-interface = gr.Interface(
-    fn=correct_grammar,
-    inputs="text",
-    outputs="text",
-    title="Grammar Correction App",
-    description="Enter a sentence or paragraph to get grammar corrections using a Seq2Seq grammar correction model."
-)
-
+# Example usage of the grammar correction function
 if __name__ == "__main__":
-    interface.launch()
+    sample_text = "He go to the market yesturday."
+    corrected_text = correct_grammar(sample_text)
+    print("Original Text:", sample_text)
+    print("Corrected Text:", corrected_text)