snoop2head committed
Commit 869d2d6 · 1 Parent(s): 0dc4a45

fix: decoding function

Files changed (1)
  1. app.py +5 -5
app.py CHANGED
@@ -23,16 +23,16 @@ if st.button('문장 변환'):
 
     else:
         translated = ko2en_model.generate(
-            **tokenizer(src_text, return_tensors="pt", padding=True, max_length=64,),
+            **tokenizer([src_text], return_tensors="pt", padding=True, max_length=64,),
             max_length=64,
             num_beams=5,
             repetition_penalty=1.3,
             no_repeat_ngram_size=3,
             num_return_sequences=1,
         )
-
+        list_translated = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
         backtranslated = en2ko_model.generate(
-            **tokenizer(translated, return_tensors="pt", padding=True, max_length=64,),
+            **tokenizer(list_translated, return_tensors="pt", padding=True, max_length=64,),
             max_length=64,
             num_beams=5,
             repetition_penalty=1.3,
@@ -43,6 +43,6 @@ else:
     pass
 
 
-print(backtranslated)
+print([tokenizer.decode(t, skip_special_tokens=True) for t in backtranslated])
 
-st.write(backtranslated)
+st.write([tokenizer.decode(t, skip_special_tokens=True) for t in backtranslated][0])
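
For context, the whole fix comes down to one point: model.generate() returns tensors of token IDs, not strings, so the intermediate English output has to be decoded back into text before it can be tokenized again for the second model. Below is a minimal sketch of that round trip outside of Streamlit. The checkpoint names are placeholders (this diff does not show which ko-to-en / en-to-ko models the Space loads), the backtranslate helper is introduced here purely for illustration, and, mirroring the diff, a single shared tokenizer is assumed to serve both directions.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoints: substitute the Space's actual ko->en and en->ko models.
KO2EN_CKPT = "your-ko2en-checkpoint"
EN2KO_CKPT = "your-en2ko-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(KO2EN_CKPT)
ko2en_model = AutoModelForSeq2SeqLM.from_pretrained(KO2EN_CKPT)
en2ko_model = AutoModelForSeq2SeqLM.from_pretrained(EN2KO_CKPT)


def backtranslate(src_text: str) -> str:
    # Tokenize the Korean source as a batch of one sentence.
    ko_inputs = tokenizer([src_text], return_tensors="pt", padding=True, max_length=64)

    # First hop: Korean -> English. generate() returns token IDs, not text.
    en_ids = ko2en_model.generate(
        **ko_inputs,
        max_length=64,
        num_beams=5,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
    )

    # The fix in this commit: decode the intermediate IDs into plain strings
    # before handing them to the tokenizer again (it expects text, not tensors).
    en_texts = tokenizer.batch_decode(en_ids, skip_special_tokens=True)

    # Second hop: English -> Korean, re-encoding the decoded strings.
    ko_ids = en2ko_model.generate(
        **tokenizer(en_texts, return_tensors="pt", padding=True, max_length=64),
        max_length=64,
        num_beams=5,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
    )

    # Decode once more to get the back-translated sentence to display.
    return tokenizer.batch_decode(ko_ids, skip_special_tokens=True)[0]

tokenizer.batch_decode(..., skip_special_tokens=True) does the same per-sequence decoding as the list comprehension the commit adds; before this change the raw ID tensors from the first generate() call were passed straight back into tokenizer(...), which expects strings.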