Moo committed on
Commit
8a1e4a5
·
1 Parent(s): 0312e9a

Create new file

Browse files
Files changed (1) hide show
  1. correct.py +39 -0
correct.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ from transformers import AutoTokenizer, GPT2LMHeadModel
4
+
5
# Special tokens used to build the model prompt and to configure the tokenizer.
# The prompt fed to the model is: O_TKN + original sentence + C_TKN + correction.
O_TKN = '<origin>'    # prefixes the original (uncorrected) sentence
C_TKN = '<correct>'   # separates the original sentence from the correction
BOS = "</s>"          # same string as EOS; not referenced by chat() below
EOS = "</s>"          # tokenizer eos_token; generation stops when it is predicted
PAD = "<pad>"         # tokenizer pad_token
MASK = '<unused0>'    # tokenizer mask_token
SENT = '<unused1>'    # not referenced by chat() below
12
+
13
+
14
def _generate_correction(model, tokenizer, sentence, max_new_tokens):
    """Greedily decode a correction for *sentence*, one token per step.

    Args:
        model: Callable language model; ``model(input_ids)[0]`` is treated as
            per-position vocabulary scores (assumed shape
            ``(batch, seq, vocab)`` -- confirm against the model class).
        tokenizer: Tokenizer providing ``encode`` and ``convert_ids_to_tokens``.
        sentence: The original sentence typed by the user.
        max_new_tokens: Upper bound on the number of decoding steps.

    Returns:
        The generated text with the '▁' word-boundary marker replaced by
        spaces. It is not stripped; the caller decides how to present it.
    """
    answer = ''
    # Bounded loop: a model that never predicts EOS would otherwise spin forever.
    for _ in range(max_new_tokens):
        # The whole prompt is re-encoded each step, including the text
        # generated so far, exactly as the prompt format requires.
        input_ids = torch.LongTensor(
            tokenizer.encode(O_TKN + sentence + C_TKN + answer)
        ).unsqueeze(dim=0)
        scores = model(input_ids)[0]
        # Only the prediction at the last position is the "next token".
        # Indexing it directly also avoids squeeze() collapsing a
        # one-position sequence into a scalar.
        next_id = int(torch.argmax(scores[0, -1], dim=-1))
        token = tokenizer.convert_ids_to_tokens(next_id)
        if token == EOS:
            break
        answer += token.replace('▁', ' ')
    return answer


def chat(max_new_tokens=128):
    """Run an interactive sentence-correction session on stdin/stdout.

    Loads the tokenizer and the proofreading model, then repeatedly prompts
    for a sentence and prints the model's corrected version. The session ends
    when the user types ``quit`` or when stdin is closed (EOF).

    Args:
        max_new_tokens: Maximum number of tokens generated per correction.
            Guards against an endless loop if EOS is never predicted.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        'skt/kogpt2-base-v2',
        eos_token=EOS, unk_token='<unk>',
        pad_token=PAD, mask_token=MASK)
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    # Inference only: no gradients are needed for greedy decoding.
    with torch.no_grad():
        while True:
            try:
                q = input('원래문장: ').strip()
            except EOFError:
                # stdin closed (Ctrl-D / piped input exhausted): exit quietly.
                break
            if q == 'quit':
                break
            if not q:
                # Nothing to correct; prompt again instead of querying the model.
                continue
            a = _generate_correction(model, tokenizer, q, max_new_tokens)
            print(f"교정: {a.strip()}")
36
+
37
+
38
# Start the interactive session only when executed as a script, not on import.
if __name__ == "__main__":
    chat()