import re
from collections import Counter
from threading import Thread

import gradio as gr
from nltk.tokenize.treebank import TreebankWordDetokenizer
from somajo import SoMaJo
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer

from hybrid_textnorm.lexicon import Lexicon
from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization
from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
# Module-level resources, shared by all prediction functions.
text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)

lexicon_dataset_name = 'aehrm/dtaec-lexicon'
train_lexicon = Lexicon.from_dataset(lexicon_dataset_name, split='train')

detok = TreebankWordDetokenizer()
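# Illustration (hypothetical entry, not taken from the actual dataset): the
# lexicon maps a historical type to a Counter-like tally of modern candidates,
# e.g. something along the lines of
#
#     train_lexicon['Thür']                        # Counter({'Tür': 42, ...})
#     train_lexicon['Thür'].most_common(1)[0][0]   # 'Tür'
#
# most_common(1) is how the functions below pick the single best replacement.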
def predict(input_str, model_name, progress=gr.Progress()):
    tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))

    if model_name == 'type normalizer':
        stream = predict_only_type_transformer(tokenized_sentences, progress)
    elif model_name == 'type normalizer + lm':
        stream = predict_type_transformer_with_lm(tokenized_sentences, progress)
    elif model_name == 'transnormer':
        stream = predict_transnormer(tokenized_sentences, progress)
    else:
        raise ValueError(f'unknown model name: {model_name}')

    # Re-yield the stream cumulatively so the output textbox always shows the
    # full text produced so far.
    accumulated = ""
    for out in stream:
        accumulated += out
        yield accumulated
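# Streaming contract (illustrative sketch, not part of the app): each yielded
# value is the entire output so far, so the Gradio textbox simply re-renders
# the growing string on every update.
#
#     for partial in predict('Die Königinn ſaß ...', 'type normalizer'):
#         print(partial)  # grows by one chunk per iteration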
def predict_transnormer(tokenized_sentences, progress):
    model_name = 'ybracke/transnormer-19c-beta-v02'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    for tokenized_sent in tokenized_sentences:
        # Re-assemble the raw sentence string, respecting the original spacing.
        sent = ''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent)
        inputs = tokenizer([sent], return_tensors='pt')

        # model.generate() blocks, so run it in a background thread and stream
        # the decoded tokens through a fresh TextIteratorStreamer per sentence.
        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000, num_beams=1)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            # Strip special tokens before showing the text to the user.
            yield re.sub(r'(<pad>|</s>)', '', new_text)
        thread.join()
        yield '\n'
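# Note: TextIteratorStreamer yields decoded text chunk by chunk while
# generate() runs in the background thread, and stops iterating once
# generation finishes. The regex above strips the special tokens that the
# streamer passes through verbatim; constructing the streamer with
# skip_special_tokens=True would be an alternative.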
def predict_only_type_transformer(tokenized_sentences, progress):
    type_model_name = 'aehrm/dtaec-type-normalizer'
    type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
    type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)

    # Cache type-model predictions across sentences, so each OOV type is
    # normalized at most once.
    oov_replacement_probabilities = {}
    for sentence in tokenized_sentences:
        transliterated = [german_transliterate(tok.text) for tok in sentence]

        # Only query the type model for types covered neither by the training
        # lexicon nor by an earlier prediction.
        oov_types = set(transliterated) - train_lexicon.keys() - oov_replacement_probabilities.keys()
        for input_type, probas in predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8):
            oov_replacement_probabilities[input_type] = probas

        output_sent = []
        for t in transliterated:
            if t in train_lexicon.keys():
                output_sent.append(train_lexicon[t].most_common(1)[0][0])
            elif t in oov_replacement_probabilities.keys():
                output_sent.append(Counter(dict(oov_replacement_probabilities[t])).most_common(1)[0][0])
            else:
                # Every type must be covered by one of the two branches above.
                raise ValueError(f'no replacement found for type {t!r}')

        yield detok.detokenize(recombine_tokens(output_sent)) + '\n'
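# Illustration (hypothetical, assuming german_transliterate maps historical
# glyphs such as the long s to modern equivalents):
#
#     german_transliterate('ſaß')   # -> 'saß'
#
# After transliteration, in-lexicon types take their most frequent lexicon
# replacement; only the remaining OOV types go through the seq2seq type model.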
def predict_type_transformer_with_lm(tokenized_sentences, progress):
    type_model_name = 'aehrm/dtaec-type-normalizer'
    language_model_name = 'dbmdz/german-gpt2'
    type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
    type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)
    language_model_tokenizer = AutoTokenizer.from_pretrained(language_model_name)
    language_model = AutoModelForCausalLM.from_pretrained(language_model_name)

    # The GPT-2 tokenizer has no padding token by default, but reranking
    # batches candidate sequences and therefore needs one.
    if 'pad_token' not in language_model_tokenizer.special_tokens_map:
        language_model_tokenizer.add_special_tokens({'pad_token': '<pad>'})

    oov_replacement_probabilities = {}
    for sentence in tokenized_sentences:
        transliterated = [german_transliterate(tok.text) for tok in sentence]

        oov_types = set(transliterated) - train_lexicon.keys() - oov_replacement_probabilities.keys()
        for input_type, probas in predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8):
            oov_replacement_probabilities[input_type] = probas

        # Rerank candidate normalizations with the language model and keep the best one.
        predictions = reranked_normalization(transliterated, train_lexicon, oov_replacement_probabilities,
                                             language_model_tokenizer, language_model, batch_size=1)
        best_pred, _, _, _ = predictions[0]

        yield detok.detokenize(recombine_tokens(best_pred)) + '\n'
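# reranked_normalization scores whole candidate sentences with the causal LM.
# Judging from the unpacking above, it returns candidates best-first as
# 4-tuples whose first element is the normalized token sequence; only the top
# candidate is emitted.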
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Input", value="Die Königinn ſaß auf des Pallaſtes mittlerer Tribune."),
        gr.Dropdown(
            [
                ('aehrm/dtaec-type-normalizer (fast)', 'type normalizer'),
                ('aehrm/dtaec-type-normalizer + dbmdz/german-gpt2 (fast)', 'type normalizer + lm'),
                ('ybracke/transnormer-19c-beta-v02 (fast)', 'transnormer'),
            ],
            label="Model",
        ),
    ],
    outputs=gr.Textbox(label="Output", show_label=True),
    title="German Historical Text Normalization",
)

if __name__ == "__main__":
    gradio_app.launch()