snoop2head committed on
Commit
2beda2c
·
1 Parent(s): d19be82

update with stabilized performance

Browse files
Files changed (1) hide show
  1. app.py +23 -42
app.py CHANGED
@@ -16,7 +16,7 @@ en2ko_model = AutoModelForSeq2SeqLM.from_pretrained("QuoQA-NLP/KE-T5-En2Ko-Base"
16
  st.title("🤖 KoQuillBot")
17
 
18
 
19
- default_value = "한국어 문장 변환기 QuillBot입니다."
20
  src_text = st.text_area(
21
  "바꾸고 싶은 문장을 입력하세요:",
22
  default_value,
@@ -26,59 +26,40 @@ src_text = st.text_area(
26
  print(src_text)
27
 
28
 
29
- def infer_sentence(model, src_text, tokenizer=tokenizer):
30
- encoded_prompt = tokenizer.encode(
31
- src_text,
32
- add_special_tokens=False,
33
- return_tensors="pt",
34
- padding=True,
35
- max_length=64,
36
- )
37
- if encoded_prompt.size()[-1] == 0:
38
- input_ids = None
39
- else:
40
- input_ids = encoded_prompt
41
-
42
- output_sequences = model.generate(
43
- input_ids=input_ids,
44
- max_length=64,
45
- num_beams=5,
46
- repetition_penalty=1.3,
47
- no_repeat_ngram_size=3,
48
- num_return_sequences=1,
49
- )
50
- print(output_sequences)
51
-
52
- generated_sequence = output_sequences[0]
53
- print(generated_sequence)
54
-
55
- # Decode text
56
- text = tokenizer.decode(
57
- generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True
58
- )
59
- print(text)
60
-
61
- # Remove all text after the pad token
62
- stop_token = tokenizer.eos_token
63
- text = text[: text.find(stop_token) if stop_token else None]
64
- text = text.strip()
65
- return text
66
-
67
 
68
  if st.button("문장 변환") or src_text == default_value:
69
  if src_text == "":
70
  st.warning("Please **enter text** for translation")
71
 
72
  else:
73
- english_translation = infer_sentence(
74
- model=ko2en_model, src_text=src_text, tokenizer=tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  )
76
 
 
77
  korean_translation = en2ko_model.generate(
78
  **tokenizer(
79
  english_translation,
80
  return_tensors="pt",
81
- padding=True,
 
82
  max_length=64,
83
  ),
84
  max_length=64,
 
16
  st.title("🤖 KoQuillBot")
17
 
18
 
19
+ default_value = "이건 한국어 문장 변환기 QuillBot입니다."
20
  src_text = st.text_area(
21
  "바꾸고 싶은 문장을 입력하세요:",
22
  default_value,
 
26
  print(src_text)
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  if st.button("문장 변환") or src_text == default_value:
31
  if src_text == "":
32
  st.warning("Please **enter text** for translation")
33
 
34
  else:
35
+ # translate into english sentence
36
+ english_translation = ko2en_model.generate(
37
+ **tokenizer(
38
+ src_text,
39
+ return_tensors="pt",
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=64,
43
+ ),
44
+ max_length=64,
45
+ num_beams=5,
46
+ repetition_penalty=1.3,
47
+ no_repeat_ngram_size=3,
48
+ num_return_sequences=1,
49
+ )
50
+ english_translation = tokenizer.decode(
51
+ english_translation[0],
52
+ clean_up_tokenization_spaces=True,
53
+ skip_special_tokens=True,
54
  )
55
 
56
+ # translate back to korean
57
  korean_translation = en2ko_model.generate(
58
  **tokenizer(
59
  english_translation,
60
  return_tensors="pt",
61
+ padding="max_length",
62
+ truncation=True,
63
  max_length=64,
64
  ),
65
  max_length=64,