File size: 2,772 Bytes
d6a3883
 
2f789a3
55292dd
 
 
2f789a3
70c38a6
 
2f789a3
55292dd
fdab1ea
55292dd
0f6f16f
55292dd
8f0e5c7
 
 
 
 
55292dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f6f16f
fdab1ea
 
 
5d30db9
55292dd
 
04a455f
00a12a3
04a455f
 
fdab1ea
 
 
 
 
 
04a455f
4e606ee
8f0e5c7
 
752d60b
cd58992
8f0e5c7
 
4e606ee
8f0e5c7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-

import gradio as gr
import operator
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Hugging Face checkpoint id of the MacBERT Chinese spelling-correction model.
# The tokenizer and the masked-LM weights are loaded from the same checkpoint,
# once, at import time; ai_text decodes the model's predicted ids with this
# tokenizer, so the two must share a vocabulary.
_MODEL_ID = "shibing624/macbert4csc-base-chinese"
tokenizer = BertTokenizer.from_pretrained(_MODEL_ID)
model = BertForMaskedLM.from_pretrained(_MODEL_ID)


# Characters that are expected to be missing from the decoded model output
# (the original code labels them "unk words"; presumably the tokenizer treats
# them as whitespace or [UNK], which decoding drops). They are copied back from
# the original text so the corrected string stays position-aligned with it.
_UNK_CHARS = frozenset([' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤'])


def _get_errors(corrected_text, origin_text):
    """Re-align ``corrected_text`` with ``origin_text`` and list substitutions.

    Args:
        corrected_text: Model output for ``origin_text``; characters from
            ``_UNK_CHARS`` may be missing and letters may be lower-cased.
        origin_text: The text that was submitted for correction.

    Returns:
        A ``(text, details)`` tuple. ``text`` is ``corrected_text`` with the
        missing characters re-inserted and original letter case restored.
        ``details`` is a list of ``(original_char, corrected_char, start, end)``
        tuples, one per substituted character, in ascending position order.
    """
    details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in _UNK_CHARS:
            # Put back a character that did not survive decoding.
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # Only the letter case differs: keep the original (upper-case)
                # character and do not report it as a correction.
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            details.append((ori_char, corrected_text[i], i, i + 1))
    # Entries are appended while scanning positions left to right, so
    # ``details`` is already ordered by start offset; no sort is needed.
    return corrected_text, details


def ai_text(text):
    """Correct Chinese spelling errors in ``text`` with the MacBERT CSC model.

    Side effect: prints the input, the corrected text and the substitution
    list to stdout.

    Args:
        text: The sentence to correct.

    Returns:
        The corrected text, a single space, and ``str()`` of the list of
        ``(original_char, corrected_char, start, end)`` substitutions.
    """
    with torch.no_grad():
        outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))

    # Most likely token at every position of the single batch item, decoded
    # without special tokens; spaces in the decoded string are stripped.
    # NOTE(review): the tokenizer call does not truncate, so input longer than
    # the model's maximum sequence length is not handled here — confirm intent.
    predicted_ids = torch.argmax(outputs.logits[0], dim=-1)
    decoded = tokenizer.decode(predicted_ids, skip_special_tokens=True).replace(' ', '')
    # Keep at most one output character per input character before aligning.
    corrected_text, details = _get_errors(decoded[:len(text)], text)
    print(text, ' => ', corrected_text, details)
    return corrected_text + ' ' + str(details)


if __name__ == '__main__':
    # Run the corrector once on the console before starting the web UI.
    print(ai_text('少先队员因该为老人让坐'))

    # Sample inputs shown under the form. Each example is wrapped in a
    # one-element list because the interface has a single text input.
    sample_sentences = (
        '真麻烦你了。希望你们好好的跳无',
        '少先队员因该为老人让坐',
        '机七学习是人工智能领遇最能体现智能的一个分知',
        '今天心情很好',
        '他法语说的很好,的语也不错',
        '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
    )
    examples = [[sentence] for sentence in sample_sentences]

    demo = gr.Interface(
        fn=ai_text,
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model shibing624/macbert4csc-base-chinese",
        description="Copy or input error Chinese text. Submit and the machine will correct text.",
        article="Link to <a href='https://github.com/shibing624/pycorrector' style='color:blue;' target='_blank'>Github REPO: pycorrector</a>",
        examples=examples,
    )
    demo.launch()