aehrm committed on
Commit
5af7057
·
1 Parent(s): d2673e5
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -8,10 +8,13 @@ from transformers.pipelines.pt_utils import KeyDataset
8
  from hybrid_textnorm.lexicon import Lexicon
9
  from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization, prior_normalization
10
  from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
 
11
 
12
  text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)
13
  lexicon_dataset_name = 'aehrm/dtaec-lexicon'
14
  train_lexicon = Lexicon.from_dataset(lexicon_dataset_name, split='train')
 
 
15
 
16
  def predict(input_str, model_name, progress=gr.Progress()):
17
  tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))
@@ -24,7 +27,6 @@ def predict(input_str, model_name, progress=gr.Progress()):
24
  output_sentences = predict_transnormer(tokenized_sentences, progress)
25
 
26
  if type(output_sentences[0]) == list:
27
- detok = TreebankWordDetokenizer()
28
  return "\n".join([detok.detokenize(recombine_tokens(sent)) for sent in output_sentences])
29
  else:
30
  return "\n".join(output_sentences)
@@ -32,18 +34,20 @@ def predict(input_str, model_name, progress=gr.Progress()):
32
  def predict_transnormer(tokenized_sentences, progress):
33
  model_name = 'ybracke/transnormer-19c-beta-v02'
34
 
35
- progress(0, desc='running normalization')
36
- pipe = pipeline(model='ybracke/transnormer-19c-beta-v02')
37
 
38
  raw_sentences = []
39
  for tokenized_sent in tokenized_sentences:
40
- raw_sentences.append(''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent))
 
 
41
 
42
  progress(0, desc='running normalization')
43
  ds = KeyDataset(Dataset.from_dict(dict(types=list(raw_sentences))), "types")
44
 
45
  output_sentences = []
46
- for out_sentence in progress.tqdm(pipe(ds, num_beams=4, max_length=1000)):
47
  output_sentences.append(out_sentence[0]['generated_text'])
48
 
49
  return output_sentences
@@ -107,6 +111,7 @@ def predict_type_transformer_with_lm(tokenized_sentences, progress):
107
  oov_replacement_probabilities[input_type] = probas
108
 
109
  output_sentences = []
 
110
  for hist_sent in progress.tqdm(transliterated_sentences):
111
  predictions = reranked_normalization(hist_sent, train_lexicon, oov_replacement_probabilities, language_model_tokenizer, language_model, batch_size=1)
112
  best_pred, _, _, _ = predictions[0]
 
8
  from hybrid_textnorm.lexicon import Lexicon
9
  from hybrid_textnorm.normalization import predict_type_normalization, reranked_normalization, prior_normalization
10
  from hybrid_textnorm.preprocess import recombine_tokens, german_transliterate
11
+ from tqdm import tqdm
12
 
13
  text_tokenizer = SoMaJo("de_CMC", split_camel_case=True)
14
  lexicon_dataset_name = 'aehrm/dtaec-lexicon'
15
  train_lexicon = Lexicon.from_dataset(lexicon_dataset_name, split='train')
16
+ detok = TreebankWordDetokenizer()
17
+
18
 
19
  def predict(input_str, model_name, progress=gr.Progress()):
20
  tokenized_sentences = list(text_tokenizer.tokenize_text([input_str]))
 
27
  output_sentences = predict_transnormer(tokenized_sentences, progress)
28
 
29
  if type(output_sentences[0]) == list:
 
30
  return "\n".join([detok.detokenize(recombine_tokens(sent)) for sent in output_sentences])
31
  else:
32
  return "\n".join(output_sentences)
 
def predict_transnormer(tokenized_sentences, progress):
    """Normalize historical German text with the transnormer seq2seq model.

    Detokenizes the SoMaJo sentences back into raw strings, feeds them
    through the `ybracke/transnormer-19c-beta-v02` text-generation pipeline,
    and returns one normalized string per input sentence.

    :param tokenized_sentences: iterable of tokenized sentences; each token
        exposes ``.text`` and ``.space_after`` (SoMaJo token interface —
        TODO confirm against the tokenizer in use).
    :param progress: Gradio ``Progress``-like callable with a ``.tqdm``
        wrapper, used to surface status in the UI.
    :return: list of normalized sentence strings.
    """
    model_name = 'ybracke/transnormer-19c-beta-v02'

    progress(0, desc='loading model')
    pipe = pipeline(model=model_name)

    # Reassemble each sentence from its tokens, re-inserting the whitespace
    # recorded by the tokenizer (space_after flag).
    raw_sentences = [
        ''.join(tok.text + (' ' if tok.space_after else '') for tok in tokenized_sent)
        for tokenized_sent in tokenized_sentences
    ]

    progress(0, desc='running normalization')
    # raw_sentences is already a list; no extra list() copy needed.
    ds = KeyDataset(Dataset.from_dict(dict(types=raw_sentences)), "types")

    # The pipeline streams results lazily; progress.tqdm reports one tick
    # per normalized sentence. Each output is a list of generation dicts;
    # take the top beam's text.
    output_sentences = [
        out_sentence[0]['generated_text']
        for out_sentence in progress.tqdm(pipe(ds, num_beams=4, max_length=1024))
    ]

    return output_sentences
 
111
  oov_replacement_probabilities[input_type] = probas
112
 
113
  output_sentences = []
114
+ progress(0, desc='running LM re-ranking')
115
  for hist_sent in progress.tqdm(transliterated_sentences):
116
  predictions = reranked_normalization(hist_sent, train_lexicon, oov_replacement_probabilities, language_model_tokenizer, language_model, batch_size=1)
117
  best_pred, _, _, _ = predictions[0]