sashdev committed on
Commit fdeaa3e · verified · 1 Parent(s): 3105a3a

Update app.py

Files changed (1)
  1. app.py +19 -18
app.py CHANGED
@@ -1,34 +1,35 @@
  import gradio as gr
  import torch
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM

- # Load the grammar correction model
- model_name = "microsoft/deberta-v3-base"
+ # Load the GPT model and tokenizer
+ model_name = "openai-community/openai-gpt"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)

- # Disable fast tokenization by setting `use_fast=False`
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-
- # Function to correct grammar
+ # Function for grammar correction using GPT
  def correct_grammar(text):
-     # Encode input text
-     inputs = tokenizer.encode(text, return_tensors="pt")
-
-     # Generate the corrected text
-     with torch.no_grad():
-         outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)
+     # Prepare the prompt for grammar correction
+     prompt = f"Correct the grammar of the following sentence:\n{text}\nCorrected: "
+
+     # Encode the input text and generate output
+     inputs = tokenizer.encode(prompt, return_tensors="pt")
+     outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)

-     # Decode the corrected text
+     # Decode the generated text and return the corrected sentence
      corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Post-process the output to extract the corrected sentence
+     corrected_text = corrected_text.replace(prompt, "").strip()  # Clean up the result
      return corrected_text

- # Gradio Interface
+ # Gradio interface for the grammar correction app
  interface = gr.Interface(
      fn=correct_grammar,
      inputs="text",
      outputs="text",
-     title="Grammar Correction",
-     description="Enter a sentence or paragraph to receive grammar corrections using DeBERTa."
+     title="Grammar Correction with GPT",
+     description="Enter a sentence or paragraph to receive grammar corrections using the OpenAI GPT model."
  )

  if __name__ == "__main__":
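The hunk ends at the `if __name__ == "__main__":` guard, so the launch call itself is not shown. Below is a minimal sketch of how the updated module could be exercised locally; the `interface.launch()` call is an assumption based on standard Gradio usage, not something visible in this diff.

# Sketch: exercising the updated correct_grammar() outside the Gradio UI.
# Assumes app.py is importable from the current directory; importing it
# loads the "openai-community/openai-gpt" weights at module level.
from app import correct_grammar, interface

sample = "She go to the market yesterday and buy apples."
print(correct_grammar(sample))  # prints the model's post-processed output

# Standard Gradio usage (assumed to be what app.py does under its
# `if __name__ == "__main__":` guard, which is cut off in the hunk above):
interface.launch()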