File size: 5,536 Bytes
a3620e9
 
 
 
aa2ec94
 
a3620e9
 
 
 
 
5af7057
aa2ec94
 
a3620e9
 
 
 
5af7057
 
a3620e9
 
 
 
 
aa2ec94
a3620e9
aa2ec94
a3620e9
aa2ec94
a3620e9
aa2ec94
 
 
 
a3620e9
 
 
 
aa2ec94
 
 
 
a3620e9
 
 
5af7057
 
aa2ec94
a3620e9
aa2ec94
 
 
 
 
 
a3620e9
 
 
 
 
aa2ec94
 
 
a3620e9
 
 
aa2ec94
a3620e9
aa2ec94
 
 
 
 
a3620e9
 
aa2ec94
a3620e9
 
aa2ec94
 
a3620e9
 
 
aa2ec94
a3620e9
 
 
 
 
aa2ec94
a3620e9
 
 
 
 
 
 
 
aa2ec94
 
 
a3620e9
aa2ec94
 
 
a3620e9
aa2ec94
a3620e9
aa2ec94
a3620e9
 
 
 
d2673e5
 
a3620e9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from nltk.tokenize.treebank import TreebankWordDetokenizer
from somajo import SoMaJo

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextStreamer, TextIteratorStreamer
from threading import Thread
from datasets import Dataset
from transformers.pipelines.pt_utils import KeyDataset
from hybrid_textnorm.lexicon import Lexicon
from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization, prior_normalization
from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
from tqdm import tqdm
import re
from collections import Counter

# Sentence splitter and tokenizer (SoMaJo "de_CMC" model, camel-case splitting
# enabled); predict() uses it to segment the raw input into token sequences.
text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)
# Dataset identifier the type lexicon is loaded from.
lexicon_dataset_name = 'aehrm/dtaec-lexicon'
# Lexicon built from the dataset's 'train' split, loaded once at import time.
# Keys are transliterated input types; values are used like Counters of
# normalized forms (callers take `.most_common(1)`).
# NOTE(review): loading presumably fetches the dataset on first run — confirm.
train_lexicon = Lexicon.from_dataset(lexicon_dataset_name, split='train')
# Rejoins lists of normalized tokens into a single sentence string.
detok = TreebankWordDetokenizer()


def predict(input_str, model_name, progress=gr.Progress()):
    """Normalize historical German text, streaming the growing output.

    The input is split into sentences and tokenized with SoMaJo, then handed
    to the backend selected by ``model_name``. Each backend yields text
    fragments. This generator yields the accumulated output after every
    fragment, so the Gradio output box updates incrementally.

    Args:
        input_str: Raw input text.
        model_name: One of ``'type normalizer'``, ``'type normalizer + lm'``
            or ``'transnormer'`` (the values of the model dropdown).
        progress: Gradio progress tracker, forwarded to the backend.

    Yields:
        The concatenation of all fragments produced so far.

    Raises:
        gr.Error: If ``model_name`` does not name a known backend.
    """
    backends = {
        'type normalizer': predict_only_type_transformer,
        'type normalizer + lm': predict_type_transformer_with_lm,
        'transnormer': predict_transnormer,
    }
    backend = backends.get(model_name)
    if backend is None:
        # An unknown or missing selection would otherwise crash later with a
        # confusing UnboundLocalError; report it to the user instead.
        raise gr.Error(f'Unknown model: {model_name!r}')

    tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))
    stream = backend(tokenized_sentences, progress)

    accumulated = ""
    for fragment in stream:
        accumulated += fragment
        yield accumulated

def predict_transnormer(tokenized_sentences, progress):
    """Normalize sentences with the transnormer seq2seq model, streaming output.

    Each tokenized sentence is rebuilt as a plain string and generated in a
    background thread. The decoded text is streamed back through a
    ``TextIteratorStreamer``.

    Args:
        tokenized_sentences: Sentences as produced by SoMaJo; each is a
            sequence of tokens exposing ``.text`` and ``.space_after``.
        progress: Gradio progress tracker (currently unused).

    Yields:
        Decoded text fragments with ``<pad>``/``</s>`` markers removed, and a
        newline after each sentence.

    Raises:
        Exception: Re-raises whatever ``model.generate`` raised in the worker
            thread, instead of hanging on the stream.
    """
    model_name = 'ybracke/transnormer-19c-beta-v02'

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamer = TextIteratorStreamer(tokenizer)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _generate(generation_kwargs, errors):
        # Worker-thread body. On success generate() closes the stream itself.
        # If it raises, nothing would send the stop signal and the consumer
        # loop below would block forever, so record the error and close the
        # stream here.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    for tokenized_sent in tokenized_sentences:
        # Rebuild the sentence text, keeping the original inter-token spacing.
        sent = ''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent)
        inputs = tokenizer([sent], return_tensors='pt')
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000, num_beams=1)

        errors = []
        thread = Thread(target=_generate, args=(generation_kwargs, errors))
        thread.start()
        for new_text in streamer:
            # Strip special-token markers that show up in the decoded stream.
            yield re.sub(r'(<pad>|</s>)', '', new_text)
        thread.join()
        if errors:
            raise errors[0]
        yield '\n'



def predict_only_type_transformer(tokenized_sentences, progress):
    """Normalize sentences type by type, without language-model reranking.

    Each token is transliterated with ``german_transliterate``. Types found in
    the training lexicon get their most frequent normalization from it. All
    other (out-of-vocabulary) types are normalized by the type-level seq2seq
    model, taking its highest-scoring candidate.

    Args:
        tokenized_sentences: Sentences as produced by SoMaJo; each is a
            sequence of tokens exposing ``.text``.
        progress: Gradio progress tracker (currently unused).

    Yields:
        One detokenized, normalized sentence per input sentence, each followed
        by a newline.

    Raises:
        ValueError: If a type has neither a lexicon entry nor a model
            prediction (not expected to happen).
    """
    type_model_name = 'aehrm/dtaec-type-normalizer'
    type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
    type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)

    # Shared across sentences so a repeated OOV type is predicted only once,
    # matching the caching in predict_type_transformer_with_lm.
    oov_replacement_probabilities = {}
    for sentence in tokenized_sentences:
        transliterated = [german_transliterate(tok.text) for tok in sentence]

        oov_types = set(transliterated) - train_lexicon.keys() - oov_replacement_probabilities.keys()
        for input_type, probas in predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8):
            oov_replacement_probabilities[input_type] = probas

        output_sent = []
        for t in transliterated:
            if t in train_lexicon.keys():
                output_sent.append(train_lexicon[t].most_common(1)[0][0])
            elif t in oov_replacement_probabilities:
                output_sent.append(Counter(dict(oov_replacement_probabilities[t])).most_common(1)[0][0])
            else:
                raise ValueError(f'no normalization candidate for type {t!r}')

        yield detok.detokenize(recombine_tokens(output_sent)) + '\n'

def predict_type_transformer_with_lm(tokenized_sentences, progress):
    """Normalize sentences with type-level candidates reranked by a causal LM.

    Tokens are transliterated. Out-of-vocabulary types (those missing from
    the training lexicon) are scored by the type-level seq2seq model, and the
    results are cached across sentences. ``reranked_normalization`` then
    combines the lexicon and the cached candidates with the German GPT-2
    model, and its top-ranked output is emitted.

    Args:
        tokenized_sentences: Sentences as produced by SoMaJo; each is a
            sequence of tokens exposing ``.text``.
        progress: Gradio progress tracker (currently unused).

    Yields:
        One detokenized, normalized sentence per input sentence, each followed
        by a newline.
    """
    type_model_name = 'aehrm/dtaec-type-normalizer'
    language_model_name = 'dbmdz/german-gpt2'

    type_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
    type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)
    lm_tokenizer = AutoTokenizer.from_pretrained(language_model_name)
    lm = AutoModelForCausalLM.from_pretrained(language_model_name)
    # Register a pad token when the LM tokenizer has none.
    # NOTE(review): presumably needed for padding inside reranked_normalization.
    if 'pad_token' not in lm_tokenizer.special_tokens_map:
        lm_tokenizer.add_special_tokens({'pad_token': '<pad>'})

    # transliterated OOV type -> candidate probabilities, reused across sentences
    candidate_cache = {}
    for sentence in tokenized_sentences:
        types = [german_transliterate(token.text) for token in sentence]
        unseen = set(types) - train_lexicon.keys() - candidate_cache.keys()
        candidate_cache.update(
            predict_type_normalization(unseen, type_tokenizer, type_model, batch_size=8)
        )

        ranked = reranked_normalization(types, train_lexicon, candidate_cache, lm_tokenizer, lm, batch_size=1)
        (best_pred, _, _, _) = ranked[0]
        normalized = detok.detokenize(recombine_tokens(best_pred))
        yield normalized + '\n'
    

# Model dropdown choices: (label shown in the UI, value passed to predict()
# as ``model_name``).
MODEL_CHOICES = [
    ('aehrm/dtaec-type-normalizer (FAST)', 'type normalizer'),
    ('aehrm/dtaec-type-normalizer + dbmdz/german-gpt2 (Fast)', 'type normalizer + lm'),
    ('ybracke/transnormer-19c-beta-v02 (fast)', 'transnormer'),
]

# Web UI: a text input plus a model selector, streaming into an output box.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Input", value="Die Königinn ſaß auf des Pallaſtes mittlerer Tribune."),
        gr.Dropdown(
            MODEL_CHOICES,
            # Preselect a model so submitting without touching the dropdown
            # still passes a valid model_name to predict().
            value='type normalizer',
            label="Model",
        ),
    ],
    outputs=gr.Textbox(label="Output", show_label=True),
    title="German Historical Text Normalization",
)

if __name__ == "__main__":
    gradio_app.launch()