aehrm committed on
Commit
aa2ec94
·
1 Parent(s): 600b8f2
Files changed (1) hide show
  1. app.py +44 -54
app.py CHANGED
@@ -2,13 +2,16 @@ import gradio as gr
2
  from nltk.tokenize.treebank import TreebankWordDetokenizer
3
  from somajo import SoMaJo
4
 
5
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
6
  from datasets import Dataset
7
  from transformers.pipelines.pt_utils import KeyDataset
8
  from hybrid_textnorm.lexicon import Lexicon
9
  from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization, prior_normalization
10
  from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
11
  from tqdm import tqdm
 
 
12
 
13
  text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)
14
  lexicon_dataset_name = 'aehrm/dtaec-lexicon'
@@ -20,78 +23,72 @@ def predict(input_str, model_name, progress=gr.Progress()):
20
  tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))
21
 
22
  if model_name == 'type normalizer':
23
- output_sentences = predict_only_type_transformer(tokenized_sentences, progress)
24
  elif model_name == 'type normalizer + lm':
25
- output_sentences = predict_type_transformer_with_lm(tokenized_sentences, progress)
26
  elif model_name == 'transnormer':
27
- output_sentences = predict_transnormer(tokenized_sentences, progress)
28
 
29
- if type(output_sentences[0]) == list:
30
- return "\n".join([detok.detokenize(recombine_tokens(sent)) for sent in output_sentences])
31
- else:
32
- return "\n".join(output_sentences)
33
 
34
  def predict_transnormer(tokenized_sentences, progress):
35
  model_name = 'ybracke/transnormer-19c-beta-v02'
36
 
37
- progress(0, desc='loading model')
38
- pipe = pipeline(model=model_name)
 
 
39
 
40
  raw_sentences = []
41
  for tokenized_sent in tokenized_sentences:
42
  sent = ''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent)
43
- raw_sentences.append(sent)
44
 
 
45
 
46
- progress(0, desc='running normalization')
47
- ds = KeyDataset(Dataset.from_dict(dict(types=list(raw_sentences))), "types")
48
-
49
- output_sentences = []
50
- for out_sentence in progress.tqdm(pipe(ds, num_beams=1, max_length=1024)):
51
- output_sentences.append(out_sentence[0]['generated_text'])
52
-
53
- return output_sentences
54
 
55
 
56
 
57
  def predict_only_type_transformer(tokenized_sentences, progress):
58
  type_model_name = 'aehrm/dtaec-type-normalizer'
59
- progress(0, desc='loading model')
60
-
61
- pipe = pipeline('text2text-generation', type_model_name)
62
 
63
  transliterated_sentences = []
64
  for sentence in tokenized_sentences:
65
- transliterated_sentences.append([german_transliterate(tok.text) for tok in sentence])
66
 
67
- oov_types = set(tok for sent in transliterated_sentences for tok in sent) - train_lexicon.keys()
68
- oov_normalizations = {}
 
 
 
69
 
70
- progress(0, desc='running normalization')
71
- ds = KeyDataset(Dataset.from_dict(dict(types=list(oov_types))), "types")
72
- for in_type, out in zip(ds, progress.tqdm(pipe(ds))):
73
- oov_normalizations[in_type] = out[0]['generated_text']
74
-
75
- output_sentences = []
76
- for sent in transliterated_sentences:
77
  output_sent = []
78
- for t in sent:
79
  if t in train_lexicon.keys():
80
  output_sent.append(train_lexicon[t].most_common(1)[0][0])
81
- elif t in oov_normalizations.keys():
82
- output_sent.append(oov_normalizations[t])
83
  else:
84
  raise ValueError()
85
 
86
- output_sentences.append(output_sent)
87
-
88
- return output_sentences
89
 
90
  def predict_type_transformer_with_lm(tokenized_sentences, progress):
91
  type_model_name = 'aehrm/dtaec-type-normalizer'
92
  language_model_name = 'dbmdz/german-gpt2'
93
 
94
- progress(0, desc='loading model')
95
  type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
96
  type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)
97
  language_model_tokenizer = AutoTokenizer.from_pretrained(language_model_name)
@@ -99,25 +96,18 @@ def predict_type_transformer_with_lm(tokenized_sentences, progress):
99
  if 'pad_token' not in language_model_tokenizer.special_tokens_map:
100
  language_model_tokenizer.add_special_tokens({'pad_token': '<pad>'})
101
 
102
- transliterated_sentences = []
103
- for sentence in tokenized_sentences:
104
- transliterated_sentences.append([german_transliterate(tok.text) for tok in sentence])
105
-
106
- oov_types = set(tok for sent in transliterated_sentences for tok in sent) - train_lexicon.keys()
107
  oov_replacement_probabilities = {}
 
 
 
108
 
109
- progress(0, desc='running normalization')
110
- for input_type, probas in progress.tqdm(predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8), total=len(oov_types)):
111
- oov_replacement_probabilities[input_type] = probas
112
 
113
- output_sentences = []
114
- progress(0, desc='running LM re-ranking')
115
- for hist_sent in progress.tqdm(transliterated_sentences):
116
- predictions = reranked_normalization(hist_sent, train_lexicon, oov_replacement_probabilities, language_model_tokenizer, language_model, batch_size=1)
117
  best_pred, _, _, _ = predictions[0]
118
- output_sentences.append(best_pred)
119
-
120
- return output_sentences
121
 
122
 
123
  gradio_app = gr.Interface(
 
2
  from nltk.tokenize.treebank import TreebankWordDetokenizer
3
  from somajo import SoMaJo
4
 
5
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextStreamer, TextIteratorStreamer
6
+ from threading import Thread
7
  from datasets import Dataset
8
  from transformers.pipelines.pt_utils import KeyDataset
9
  from hybrid_textnorm.lexicon import Lexicon
10
  from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization, prior_normalization
11
  from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
12
  from tqdm import tqdm
13
+ import re
14
+ from collections import Counter
15
 
16
  text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)
17
  lexicon_dataset_name = 'aehrm/dtaec-lexicon'
 
23
  tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))
24
 
25
  if model_name == 'type normalizer':
26
+ stream = predict_only_type_transformer(tokenized_sentences, progress)
27
  elif model_name == 'type normalizer + lm':
28
+ stream = predict_type_transformer_with_lm(tokenized_sentences, progress)
29
  elif model_name == 'transnormer':
30
+ stream = predict_transnormer(tokenized_sentences, progress)
31
 
32
+ accumulated = ""
33
+ for out in stream:
34
+ accumulated += out
35
+ yield accumulated
36
 
37
  def predict_transnormer(tokenized_sentences, progress):
38
  model_name = 'ybracke/transnormer-19c-beta-v02'
39
 
40
+ #progress(0, desc='loading model')
41
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
42
+ streamer = TextIteratorStreamer(tokenizer)
43
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
44
 
45
  raw_sentences = []
46
  for tokenized_sent in tokenized_sentences:
47
  sent = ''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent)
 
48
 
49
+ inputs = tokenizer([sent], return_tensors='pt')
50
 
51
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000, num_beams=1)
52
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
53
+ thread.start()
54
+ for new_text in streamer:
55
+ yield re.sub(r'(<pad>|</s>)', '', new_text)
56
+ yield '\n'
 
 
57
 
58
 
59
 
60
  def predict_only_type_transformer(tokenized_sentences, progress):
61
  type_model_name = 'aehrm/dtaec-type-normalizer'
62
+ #progress(0, desc='loading model')
63
+ type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
64
+ type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)
65
 
66
  transliterated_sentences = []
67
  for sentence in tokenized_sentences:
68
+ transliterated = [german_transliterate(tok.text) for tok in sentence]
69
 
70
+ oov_replacement_probabilities = {}
71
+ oov_types = set(transliterated) - train_lexicon.keys() - oov_replacement_probabilities.keys()
72
+ #print('oov:', oov_types)
73
+ for input_type, probas in predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8):
74
+ oov_replacement_probabilities[input_type] = probas
75
 
 
 
 
 
 
 
 
76
  output_sent = []
77
+ for t in transliterated:
78
  if t in train_lexicon.keys():
79
  output_sent.append(train_lexicon[t].most_common(1)[0][0])
80
+ elif t in oov_replacement_probabilities.keys():
81
+ output_sent.append(Counter(dict(oov_replacement_probabilities[t])).most_common(1)[0][0])
82
  else:
83
  raise ValueError()
84
 
85
+ yield detok.detokenize(recombine_tokens(output_sent)) + '\n'
 
 
86
 
87
  def predict_type_transformer_with_lm(tokenized_sentences, progress):
88
  type_model_name = 'aehrm/dtaec-type-normalizer'
89
  language_model_name = 'dbmdz/german-gpt2'
90
 
91
+ #progress(0, desc='loading model')
92
  type_model_tokenizer = AutoTokenizer.from_pretrained(type_model_name)
93
  type_model = AutoModelForSeq2SeqLM.from_pretrained(type_model_name)
94
  language_model_tokenizer = AutoTokenizer.from_pretrained(language_model_name)
 
96
  if 'pad_token' not in language_model_tokenizer.special_tokens_map:
97
  language_model_tokenizer.add_special_tokens({'pad_token': '<pad>'})
98
 
 
 
 
 
 
99
  oov_replacement_probabilities = {}
100
+ for sentence in tokenized_sentences:
101
+ transliterated = [german_transliterate(tok.text) for tok in sentence]
102
+ oov_types = set(transliterated) - train_lexicon.keys() - oov_replacement_probabilities.keys()
103
 
104
+ #print('oov:', oov_types)
105
+ for input_type, probas in predict_type_normalization(oov_types, type_model_tokenizer, type_model, batch_size=8):
106
+ oov_replacement_probabilities[input_type] = probas
107
 
108
+ predictions = reranked_normalization(transliterated, train_lexicon, oov_replacement_probabilities, language_model_tokenizer, language_model, batch_size=1)
 
 
 
109
  best_pred, _, _, _ = predictions[0]
110
+ yield detok.detokenize(recombine_tokens(best_pred)) + '\n'
 
 
111
 
112
 
113
  gradio_app = gr.Interface(