r1208 commited on
Commit
46e037c
·
verified ·
1 Parent(s): 0312d8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -3,7 +3,9 @@ import streamlit as st
3
  from transformers import pipeline
4
  from PIL import Image
5
  import os
6
-
 
 
7
 
8
  def main():
9
 
@@ -16,14 +18,16 @@ def main():
16
  return tokens_list
17
 
18
 
19
- def translate(text, tokenizer, model, bad_words_ids):
20
  # Prepare the prompt
21
- messages = f"Translate from Korean to English: {text}"
 
 
22
  input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
23
  prompt_padded_len = len(input_ids[0])
24
 
25
  # Generate the translation
26
- gen_tokens = model.generate(input_ids, max_length=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, bad_words_ids = bad_words_ids)
27
  gen_tokens = [
28
  gt[prompt_padded_len:] for gt in gen_tokens
29
  ]
@@ -46,9 +50,7 @@ def main():
46
 
47
  hf_token = os.getenv("HF_ACCESS_TOKEN")
48
 
49
- from peft import AutoPeftModelForCausalLM
50
- from transformers import AutoTokenizer
51
- import torch
52
 
53
  # attn_implementation = None
54
  # USE_FLASH_ATTENTION = False
@@ -72,14 +74,16 @@ def main():
72
  temperature = st.sidebar.slider("Temperature", value=0.3, min_value=0.0, max_value=1.0, step=0.05)
73
  top_k = st.sidebar.slider("Top-k", min_value=0, max_value=50, value=0)
74
  top_p = st.sidebar.slider("Top-p", min_value=0.75, max_value=1.0, step=0.05, value=0.9)
 
 
75
 
76
 
77
  st.subheader("Enter text to translate")
78
- input_text = st.text_area("Text to Translate", value= "Korean text here", height=300)
79
 
80
  if st.button("Translate"):
81
  if input_text:
82
- translation = translate(input_text, tokenizer, model, bad_words_ids)
83
  st.text_area("Translated Text", value=translation, height=300)
84
  else:
85
  st.error("Please enter some text to translate.")
 
3
  from transformers import pipeline
4
  from PIL import Image
5
  import os
6
+ from peft import AutoPeftModelForCausalLM
7
+ from transformers import AutoTokenizer
8
+ import torch
9
 
10
  def main():
11
 
 
18
  return tokens_list
19
 
20
 
21
+ def translate(text, tokenizer, model, do_sample, max_new_tokens, temperature, top_k, top_p, bad_words_ids):
22
  # Prepare the prompt
23
+ prompts = f"Translate from Korean to English: {text}"
24
+ messages = [{"role": "user", "content": prompts}]
25
+
26
  input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
27
  prompt_padded_len = len(input_ids[0])
28
 
29
  # Generate the translation
30
+ gen_tokens = model.generate(input_ids, do_sample = do_sample, max_length=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, bad_words_ids = bad_words_ids)
31
  gen_tokens = [
32
  gt[prompt_padded_len:] for gt in gen_tokens
33
  ]
 
50
 
51
  hf_token = os.getenv("HF_ACCESS_TOKEN")
52
 
53
+
 
 
54
 
55
  # attn_implementation = None
56
  # USE_FLASH_ATTENTION = False
 
74
  temperature = st.sidebar.slider("Temperature", value=0.3, min_value=0.0, max_value=1.0, step=0.05)
75
  top_k = st.sidebar.slider("Top-k", min_value=0, max_value=50, value=0)
76
  top_p = st.sidebar.slider("Top-p", min_value=0.75, max_value=1.0, step=0.05, value=0.9)
77
+ do_sample = st.selectbox("do_sample: ",
78
+ ['True', 'False'])
79
 
80
 
81
  st.subheader("Enter text to translate")
82
+ input_text = st.text_area("Text to Translate", value= text_default, height=300)
83
 
84
  if st.button("Translate"):
85
  if input_text:
86
+ translation = translate(input_text, model, do_sample = do_sample, max_new_tokens = max_new_tokens, temperature = temperature, top_k = top_k, top_p = top_p, bad_words_ids = bad_words_ids)
87
  st.text_area("Translated Text", value=translation, height=300)
88
  else:
89
  st.error("Please enter some text to translate.")