Spaces:
Runtime error
Runtime error
File size: 4,653 Bytes
dee3f71 d8d3649 68c3660 dee3f71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
class Gramformer:
    """Grammar error correction and edit highlighting.

    ``correct`` rewrites a sentence with the seq2seq model
    ``prithivida/grammar_error_correcter_v1``. ``highlight`` and ``get_edits``
    compare an original sentence with a corrected one using an ERRANT
    annotator backed by spaCy's ``en_core_web_sm`` pipeline.
    """

    def __init__(self, models=1, use_gpu=False):
        """Load the ERRANT annotator and, depending on ``models``, the correction model.

        Args:
            models: ``1`` loads the grammar-correction model and tokenizer.
                ``2`` is not implemented and only prints a notice. With any
                value other than ``1`` no model is loaded, so ``correct``
                returns ``None``.
            use_gpu: Move the correction model to ``"cuda:0"`` instead of
                keeping it on ``"cpu"``.
        """
        # Imported here rather than at module level, so these packages are
        # only required once a Gramformer is actually constructed.
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        import errant
        import en_core_web_sm

        nlp = en_core_web_sm.load()
        self.annotator = errant.load('en', nlp)

        device = "cuda:0" if use_gpu else "cpu"
        self.device = device

        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False
        if models == 1:
            # NOTE(review): `use_auth_token` is deprecated in newer transformers
            # releases in favour of `token`; kept for compatibility with older ones.
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO: second model option is not implemented yet.
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        """Generate corrected versions of ``input_sentence``.

        Args:
            input_sentence: The sentence to correct.
            max_candidates: Number of sequences requested from the model.
                Identical decodings collapse, so the result may hold fewer.

        Returns:
            A set of corrected strings (whitespace-stripped), or ``None``
            after printing a message when no correction model was loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        # Task prefix prepended to every input for this correction model.
        correction_prefix = "gec: "
        input_ids = self.correction_tokenizer.encode(correction_prefix + input_sentence, return_tensors='pt')
        input_ids = input_ids.to(self.device)

        # Beam-based generation cannot return more sequences than it has
        # beams, so widen the beam when more candidates are requested than
        # the default width of 7 (otherwise `generate` raises).
        num_beams = max(7, max_candidates)
        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            num_beams=num_beams,
            early_stopping=True,
            num_return_sequences=max_candidates,
        )
        return {
            self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip()
            for pred in preds
        }

    def highlight(self, orig, cor):
        """Return ``orig`` with inline tags marking the edits that produce ``cor``.

        ``orig`` is split on whitespace. The positions returned by
        ``_get_edits`` are used as indexes into that token list. Each edit
        becomes one tag that replaces the first token of its span; any further
        tokens in the span are removed:

        * ``<a type='T' edit='...'>anchor</a>``: insertion (empty original
          span). The tag is attached to a neighbouring original token. That
          is the token before the insertion point, or the first token when
          inserting at the very start. For ``PUNCT`` the ``edit`` value is
          just the inserted text; otherwise it is the anchor and inserted
          text in reading order.
        * ``<d type='T' edit=''>old</d>``: deletion.
        * ``<c type='T' edit='new'>old</c>``: replacement.

        The remaining tokens are re-joined with single spaces.
        """
        # NOTE(review): edit/anchor text is inserted unescaped, so a value
        # containing a single quote (e.g. "don't") yields a malformed
        # attribute. Consider escaping if downstream consumers allow it.
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        ignore_indexes = []
        for edit in edits:
            edit_type, edit_str_start, edit_spos, edit_epos, edit_str_end = edit[:5]
            # Multi-token span: the tag replaces the first token; the rest are
            # deleted after all edits are placed (so indexes stay valid meanwhile).
            ignore_indexes.extend(range(edit_spos + 1, edit_epos))

            if edit_str_start == "":
                # Insertion: there is no original token to wrap, so anchor the
                # tag on a neighbouring token.
                if edit_spos > 0:
                    # Anchor on the token before the insertion point; the new
                    # text follows it.
                    edit_spos -= 1
                    anchor = orig_tokens[edit_spos]
                    inserted = anchor + " " + edit_str_end
                elif orig_tokens:
                    # Inserting before the first token: anchor on that token
                    # (index 0); the new text precedes it.
                    anchor = orig_tokens[edit_spos]
                    inserted = edit_str_end + " " + anchor
                else:
                    # Empty original: nothing to anchor on, emit a bare tag.
                    orig_tokens.append("")
                    anchor = ""
                    inserted = edit_str_end
                edit_value = edit_str_end if edit_type == "PUNCT" else inserted
                st = "<a type='" + edit_type + "' edit='" + edit_value + "'>" + anchor + "</a>"
            elif edit_str_end == "":
                st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
            else:
                st = "<c type='" + edit_type + "' edit='" + edit_str_end + "'>" + edit_str_start + "</c>"
            orig_tokens[edit_spos] = st

        # Delete from the end so earlier indexes are not shifted.
        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]
        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        """Not implemented yet; always returns ``None``."""
        # TO BE IMPLEMENTED
        return None

    def _get_edits(self, orig, cor):
        """Return the ERRANT edits that transform ``orig`` into ``cor``.

        Each edit is a tuple
        ``(type, o_str, o_start, o_end, c_str, c_start, c_end)``:

        * ``type`` is the classified ERRANT error type with its first two
          characters (the operation prefix, e.g. ``"R:"``) removed.
        * ``o_*`` describe the span in the parsed original.
        * ``c_*`` describe the span in the parsed correction.

        Returns an empty list when there are no edits.
        """
        # NOTE(review): `highlight` indexes `orig.split()` with these offsets;
        # this assumes `annotator.parse` tokenises on whitespace (its default
        # when `tokenise` is not requested). Confirm against the ERRANT version in use.
        orig_doc = self.annotator.parse(orig)
        cor_doc = self.annotator.parse(cor)
        alignment = self.annotator.align(orig_doc, cor_doc)
        edits = self.annotator.merge(alignment)
        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
        return edit_annotations

    def get_edits(self, orig, cor):
        """Public wrapper around ``_get_edits``; see that method for the tuple layout."""
        return self._get_edits(orig, cor)
|