File size: 810 Bytes
787a546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import sys
import time
from transformers import logging
from recasepunc import CasePuncPredictor
from recasepunc import WordpieceTokenizer
from recasepunc import Config

# Restore casing and punctuation for the text file named on the command line
# and print the result to stdout.
#
# Usage: python <this script> <input.txt>

# Silence transformers log output below the error level.
logging.set_verbosity_error()

# Fail fast with a readable message before the (potentially slow) model load.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <input-text-file>")

predictor = CasePuncPredictor('checkpoint', lang="en")

# Read the whole input; the context manager closes the file handle promptly.
# Lines keep their trailing newline and are additionally joined by a space.
with open(sys.argv[1]) as input_file:
    text = " ".join(input_file.readlines())

# Each token is paired with its index; the key function handed to predict()
# selects element 1 of the pair (the token text).
tokens = list(enumerate(predictor.tokenize(text)))

# Collect output fragments and join once at the end instead of repeatedly
# concatenating onto a growing string.
pieces = []
last_char = ""  # last character emitted so far ("" while nothing is emitted)
for token, case_label, punc_label in predictor.predict(tokens, lambda x: x[1]):
    word = token[1]
    prediction = predictor.map_punc_label(
        predictor.map_case_label(word, case_label), punc_label
    )

    # No separating space when:
    #   * the token starts with an apostrophe,
    #   * the output so far ends with an apostrophe, or
    #   * the token starts with '#' (presumably a word-piece continuation
    #     such as "##ing" -- confirm against the tokenizer).
    # startswith() is also safe for an empty token, unlike indexing [0].
    attach = (
        word.startswith("'")
        or last_char == "'"
        or word.startswith("#")
    )
    piece = prediction if attach else " " + prediction
    pieces.append(piece)
    if piece:
        last_char = piece[-1]

results = "".join(pieces)

print(results.strip())