sashdev committed
Commit 043440f · verified · 1 Parent(s): fdeaa3e

Update app.py

Files changed (1):
  1. app.py +11 -15
app.py CHANGED
@@ -1,26 +1,22 @@
 import gradio as gr
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

-# Load the GPT model and tokenizer
-model_name = "openai-community/openai-gpt"
+# Load the grammar correction model and tokenizer
+model_name = "hassaanik/grammar-correction-model"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

-# Function for grammar correction using GPT
+# Function for grammar correction using the grammar correction model
 def correct_grammar(text):
-    # Prepare the prompt for grammar correction
-    prompt = f"Correct the grammar of the following sentence:\n{text}\nCorrected: "
+    # Tokenize the input text
+    inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)

-    # Encode the input text and generate output
-    inputs = tokenizer.encode(prompt, return_tensors="pt")
+    # Generate the corrected output from the model
     outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)

-    # Decode the generated text and return the corrected sentence
+    # Decode the generated tokens to get the corrected text
     corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

-    # Post-process the output to extract the corrected sentence
-    corrected_text = corrected_text.replace(prompt, "").strip()  # Clean up the result
     return corrected_text

 # Gradio interface for the grammar correction app
@@ -28,8 +24,8 @@ interface = gr.Interface(
     fn=correct_grammar,
     inputs="text",
     outputs="text",
-    title="Grammar Correction with GPT",
-    description="Enter a sentence or paragraph to receive grammar corrections using the OpenAI GPT model."
+    title="Grammar Correction App",
+    description="Enter a sentence or paragraph to get grammar corrections using a Seq2Seq grammar correction model."
 )

 if __name__ == "__main__":
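
For reference, a sketch of app.py as it reads after this commit, reassembled from the hunks above. The interface = gr.Interface( opening line is taken from the second hunk header; the interface.launch() call under the __main__ guard is assumed from typical Gradio usage and is not shown in the diff.

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the grammar correction model and tokenizer
model_name = "hassaanik/grammar-correction-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Function for grammar correction using the grammar correction model
def correct_grammar(text):
    # Tokenize the input text
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)

    # Generate the corrected output from the model
    outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)

    # Decode the generated tokens to get the corrected text
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return corrected_text

# Gradio interface for the grammar correction app
interface = gr.Interface(   # opening line taken from the second hunk header
    fn=correct_grammar,
    inputs="text",
    outputs="text",
    title="Grammar Correction App",
    description="Enter a sentence or paragraph to get grammar corrections using a Seq2Seq grammar correction model."
)

if __name__ == "__main__":
    interface.launch()   # assumed: the launch call lies outside the diff context

With gradio and transformers installed and the hassaanik/grammar-correction-model checkpoint reachable on the Hub, running python app.py should start the web UI locally.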