snoop2head committed on
Commit
ec0b50d
·
1 Parent(s): 2a9d3fa

update: code

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +65 -22
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .python-version
app.py CHANGED
@@ -16,40 +16,77 @@ en2ko_model = AutoModelForSeq2SeqLM.from_pretrained("QuoQA-NLP/KE-T5-En2Ko-Base"
16
  st.title("๐Ÿค– KoQuillBot")
17
 
18
 
19
- default_value = "์•ˆ๋…•ํ•˜์„ธ์š”. ์ €๋Š” ๋ฌธ์žฅ์„ ๋‹ค์‹œ ์ž‘์„ฑํ•ด์ฃผ๋Š” KoQuillBot์ž…๋‹ˆ๋‹ค."
20
  src_text = st.text_area(
21
  "๋ฐ”๊พธ๊ณ  ์‹ถ์€ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”:",
22
  default_value,
23
  height=50,
24
  max_chars=200,
25
  )
 
26
 
 
 
 
 
 
 
 
 
27
 
28
- if st.button("๋ฌธ์žฅ ๋ณ€ํ™˜"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if src_text == "":
30
  st.warning("Please **enter text** for translation")
31
 
32
  else:
33
- translated = ko2en_model.generate(
34
- **tokenizer(
35
- [src_text],
36
- return_tensors="pt",
37
- padding=True,
38
- max_length=64,
39
- ),
40
- max_length=64,
41
- num_beams=5,
42
- repetition_penalty=1.3,
43
- no_repeat_ngram_size=3,
44
- num_return_sequences=1,
45
  )
46
 
47
- list_translated = [
48
- tokenizer.decode(t, skip_special_tokens=True) for t in translated
49
- ]
50
- backtranslated = en2ko_model.generate(
51
  **tokenizer(
52
- list_translated,
53
  return_tensors="pt",
54
  padding=True,
55
  max_length=64,
@@ -60,10 +97,16 @@ if st.button("๋ฌธ์žฅ ๋ณ€ํ™˜"):
60
  no_repeat_ngram_size=3,
61
  num_return_sequences=1,
62
  )
 
 
 
 
 
 
 
63
  else:
64
  pass
65
 
66
 
67
- print([tokenizer.decode(t, skip_special_tokens=True) for t in backtranslated])
68
-
69
- st.write([tokenizer.decode(t, skip_special_tokens=True) for t in backtranslated][0])
 
16
  st.title("๐Ÿค– KoQuillBot")
17
 
18
 
19
+ default_value = "ํ•œ๊ตญ์–ด ๋ฌธ์žฅ ๋ณ€ํ™˜๊ธฐ QuillBot์ž…๋‹ˆ๋‹ค."
20
  src_text = st.text_area(
21
  "๋ฐ”๊พธ๊ณ  ์‹ถ์€ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”:",
22
  default_value,
23
  height=50,
24
  max_chars=200,
25
  )
26
+ print(src_text)
27
 
28
+ # num_beams = st.sidebar.slider(
29
+ # "Num Beams", min_value=5, max_value=10, value=5
30
+ # ) # https://huggingface.co/blog/constrained-beam-search
31
+ # temperature = st.sidebar.slider(
32
+ # "Temperature", value=0.9, min_value=0.0, max_value=1.0, step=0.05
33
+ # )
34
+ # top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
35
+ # top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=1.0)
36
 
37
+
38
def infer_sentence(model, src_text, tokenizer=tokenizer):
    """Run one seq2seq generation pass and return the decoded sentence.

    Encodes *src_text*, beam-searches a single output sequence with *model*,
    decodes it, and trims anything after the tokenizer's EOS token.

    Args:
        model: a HuggingFace seq2seq model exposing ``generate``.
        src_text: the input sentence to translate / paraphrase.
        tokenizer: tokenizer shared by both translation models
            (defaults to the module-level ``tokenizer``).

    Returns:
        The decoded, whitespace-stripped output sentence.
    """
    encoded_prompt = tokenizer.encode(
        src_text,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
        truncation=True,  # actually enforce max_length instead of warning
        max_length=64,
    )
    # An empty prompt encodes to shape (1, 0); pass None so generate()
    # starts from the decoder start token rather than an empty tensor.
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=64,
        num_beams=5,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
    )

    # num_return_sequences=1, so only the first sequence is meaningful.
    generated_sequence = output_sequences[0]

    text = tokenizer.decode(
        generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True
    )

    # Trim anything after the EOS token, if it survived decoding.
    # NOTE: skip_special_tokens=True usually removes it already; the previous
    # `text[: text.find(stop_token) ...]` sliced with find()'s -1 sentinel and
    # silently dropped the last character of every output.
    stop_token = tokenizer.eos_token
    if stop_token:
        stop_index = text.find(stop_token)
        if stop_index != -1:
            text = text[:stop_index]
    text = text.strip()
    return text
75
+
76
+
77
+ if st.button("๋ฌธ์žฅ ๋ณ€ํ™˜") or src_text == default_value:
78
  if src_text == "":
79
  st.warning("Please **enter text** for translation")
80
 
81
  else:
82
+ st.success("Translating...")
83
+ english_translation = infer_sentence(
84
+ model=ko2en_model, src_text=src_text, tokenizer=tokenizer
 
 
 
 
 
 
 
 
 
85
  )
86
 
87
+ korean_translation = en2ko_model.generate(
 
 
 
88
  **tokenizer(
89
+ english_translation,
90
  return_tensors="pt",
91
  padding=True,
92
  max_length=64,
 
97
  no_repeat_ngram_size=3,
98
  num_return_sequences=1,
99
  )
100
+
101
+ korean_translation = tokenizer.decode(
102
+ korean_translation[0],
103
+ clean_up_tokenization_spaces=True,
104
+ skip_special_tokens=True,
105
+ )
106
+ st.success(f"{src_text} -> {english_translation} -> {korean_translation}")
107
  else:
108
  pass
109
 
110
 
111
+ st.write(korean_translation)
112
+ print(korean_translation)