---
datasets:
- liweili/c4_200m
language:
- en
---

```python
# Load the grammar-error-correction seq2seq model and its tokenizer directly.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Run on the GPU when one is available; otherwise fall back to the CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained("thenHung/english-grammar-error-correction-t5-seq2seq")
# The model is moved to the same device the tokenized inputs are sent to.
model = AutoModelForSeq2SeqLM.from_pretrained("thenHung/english-grammar-error-correction-t5-seq2seq").to(torch_device)


def correct_grammar(input_text, num_return_sequences):
    """Return candidate corrections of ``input_text`` found by beam search.

    Args:
        input_text: The text to correct. It is tokenized as a single-item
            batch and truncated/padded to 64 tokens.
        num_return_sequences: How many candidate corrections to return.

    Returns:
        A list of ``num_return_sequences`` decoded strings with special
        tokens removed.
    """
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors="pt",
    ).to(torch_device)

    # Beam search cannot return more hypotheses than it tracks, so widen the
    # beam beyond the default of 4 when the caller asks for more candidates.
    num_beams = max(4, num_return_sequences)

    # ``temperature`` is intentionally not passed: it only takes effect when
    # sampling (``do_sample=True``); in plain beam search it is unused and
    # recent transformers versions warn about it.
    translated = model.generate(
        **batch,
        max_length=64,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(translated, skip_special_tokens=True)


# Example: ask for three candidate corrections of an ungrammatical sentence.
input_text = """
He are an teachers.
"""
num_return_sequences = 3
corrected_texts = correct_grammar(input_text, num_return_sequences)
print(corrected_texts)

# output:
# ['He is a teacher.', 'He is an educator.', 'He is one of the teachers.']
```