# -*- coding: utf-8 -*-
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
# Special tokens used to frame the proofreading task for KoGPT2.
O_TKN = '<origin>'   # marks the start of the original (uncorrected) sentence
C_TKN = '<correct>'  # marks the start of the model's corrected sentence
BOS = "</s>"         # begin-of-sequence token (KoGPT2 uses the same string as EOS)
EOS = "</s>"         # end-of-sequence token; greedy decoding stops when emitted
PAD = "<pad>"        # padding token registered with the tokenizer
MASK = '<unused0>'   # mask token mapped onto one of KoGPT2's unused vocab slots
SENT = '<unused1>'   # NOTE(review): unused in this file — presumably consumed by a sibling training script; confirm before removing
def chat():
    """Interactive Korean proofreading loop.

    Loads the KoGPT2 tokenizer and a fine-tuned proofreader checkpoint,
    then repeatedly reads a sentence from stdin, greedily decodes a
    corrected version token by token, and prints it. Entering 'quit'
    exits the loop. No return value; side effects are stdin/stdout and
    a one-time model download.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        'skt/kogpt2-base-v2',
        bos_token=BOS, eos_token=EOS, unk_token='<unk>',
        pad_token=PAD, mask_token=MASK)  # BOS was defined but previously never registered
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    # no_grad() only disables autograd; eval() is still required to switch
    # dropout/normalization layers out of training mode for deterministic output.
    model.eval()
    # Hard cap on generated tokens so a model that never emits EOS
    # cannot hang the loop forever.
    max_new_tokens = 128
    with torch.no_grad():
        while True:
            q = input('원래문장: ').strip()
            if q == 'quit':
                break
            a = ''
            for _ in range(max_new_tokens):
                # Re-encode prompt + answer-so-far each step (simple but O(n^2)
                # in sequence length; acceptable for short sentences).
                input_ids = torch.LongTensor(
                    tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
                logits = model(input_ids).logits
                # Greedy decoding: the argmax token at the last position is
                # the model's next-token prediction.
                gen = tokenizer.convert_ids_to_tokens(
                    torch.argmax(logits, dim=-1).squeeze().tolist())[-1]
                if gen == EOS:
                    break
                # KoGPT2 uses SentencePiece '▁' to mark word boundaries.
                a += gen.replace('▁', ' ')
            print(f"교정: {a.strip()}")
if __name__ == "__main__":
    # Entry point: start the interactive proofreading loop.
    chat()