jeffeux committed
Commit eca00c9 · 1 Parent(s): d1f2f03
Files changed (1)
  app.py +50 -32
app.py CHANGED
@@ -1,10 +1,33 @@
+# ------------------- LIBRARIES -------------------- #
 import os, logging, torch, streamlit as st
 from transformers import (
     AutoTokenizer, AutoModelForCausalLM)
 st.balloons()
 
-device = 'cpu'
+# --------------------- HELPER --------------------- #
+def C(text, color="yellow"):
+    color_dict: dict = dict(
+        red="\033[01;31m",
+        green="\033[01;32m",
+        yellow="\033[01;33m",
+        blue="\033[01;34m",
+        magenta="\033[01;35m",
+        cyan="\033[01;36m",
+    )
+    color_dict[None] = "\033[0m"
+    return (
+        f"{color_dict.get(color, None)}"
+        f"{text}{color_dict[None]}")
+st.balloons()
 
+# ------------------ ENVIRONMENT ------------------- #
+os.environ["HF_ENDPOINT"] = "https://huggingface.co"
+device = ("cuda"
+          if torch.cuda.is_available() else "cpu")
+logging.info(C("[INFO] "f"device = {device}"))
+st.balloons()
+
+# ------------------- INITIALIZE ------------------- #
 @st.cache
 def model_init():
     tokenizer = AutoTokenizer.from_pretrained(
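Note on the new C() helper above: it wraps log text in bold ANSI color escapes so messages stand out in the Space's container logs. One subtlety is that color_dict.get(color, None) returns the Python value None for an unrecognized color name, so the literal text "None" would be printed as the prefix rather than the reset escape stored at color_dict[None]. A minimal standalone sketch with that fallback tightened to the reset code (an illustrative tweak, not what the commit ships):

def C(text, color="yellow"):
    # Bold ANSI foreground colors, mirroring app.py; None maps to reset.
    color_dict = dict(
        red="\033[01;31m", green="\033[01;32m", yellow="\033[01;33m",
        blue="\033[01;34m", magenta="\033[01;35m", cyan="\033[01;36m",
    )
    color_dict[None] = "\033[0m"
    # Unknown names fall back to the reset code instead of the string "None".
    return f"{color_dict.get(color, color_dict[None])}{text}{color_dict[None]}"

print(C("[INFO] device = cpu"))                  # bold yellow (default)
print(C("[INFO] Model init success!", "green"))  # bold green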
@@ -16,38 +39,33 @@ def model_init():
         # device_map="auto",
         # Ref. for `half`: Chan-Jan, Thanks!
     ).eval().to(device)
+    st.balloons()
+    logging.info(C("[INFO] "f"Model init success!"))
     return tokenizer, model
 
 tokenizer, model = model_init()
+st.balloons()
 
-prompt = '我是'
-with torch.no_grad():
-    [out] = model.generate(
-        **tokenizer(
-            prompt, return_tensors="pt"
-        ).to(device)
-    )
-st.text(out)
-
-# DONE 6.1s
-
-# ===== Application Startup at 2023-02-23 17:51:48 =====
-
-# 2023-02-23 18:52:26.009 INFO matplotlib.font_manager: generated new fontManager
-
-# Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
-
-
-# You can now view your Streamlit app in your browser.
-
-# Network URL: http://10.19.49.246:8501
-# External URL: http://34.197.127.12:8501
-
-
-# A new version of Streamlit is available.
-
-# See what's new at https://discuss.streamlit.io/c/announcements
-
-# Enter the following command to upgrade:
-# $ pip install streamlit --upgrade
-
+try:
+    # ===================== INPUT ====================== #
+    # prompt = "\u554F\uFF1A\u53F0\u7063\u6700\u9AD8\u7684\u5EFA\u7BC9\u7269\u662F\uFF1F\u7B54\uFF1A" #@param {type:"string"}
+    prompt = st.text_input("Prompt: ")
+    st.balloons()
+
+
+    # =================== INFERENCE ==================== #
+    if prompt:
+        st.balloons()
+        with torch.no_grad():
+            [texts_out] = model.generate(
+                **tokenizer(
+                    prompt, return_tensors="pt"
+                ).to(device))
+        st.balloons()
+        output_text = tokenizer.decode(texts_out)
+        st.balloons()
+        st.markdown(output_text)
+        st.balloons()
+except Exception as err:
+    st.write(str(err))
+    st.snow()
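Two things change in the application flow above: the prompt moves from a hard-coded string to st.text_input, and the output is now run through tokenizer.decode before display, where the deleted code passed the raw token-id tensor straight to st.text(out). A minimal sketch of the same tokenize, generate, decode round trip outside Streamlit; "gpt2" is a stand-in checkpoint and max_new_tokens an added cap, since the diff shows neither the real model name loaded in model_init() nor any generation settings:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to(device)

with torch.no_grad():
    # generate() returns a batch of sequences; destructure the single row,
    # matching the [texts_out] = ... pattern in app.py.
    [ids] = model.generate(
        **tokenizer("The tallest building in Taiwan is",
                    return_tensors="pt").to(device),
        max_new_tokens=20)
print(tokenizer.decode(ids))  # token ids -> text, the step the old code skipped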
 
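A last note on the cached initializer: @st.cache hashes the cached function's output on each rerun, which is slow or can fail outright for torch modules, and the deleted log block above even records Streamlit's own upgrade notice from the same period. A hedged sketch of the two usual patterns, depending on which Streamlit version the Space pins (requirements.txt is not part of this diff):

import streamlit as st

# Older Streamlit: keep @st.cache but tell it not to hash the returned objects.
@st.cache(allow_output_mutation=True)
def model_init_legacy():
    ...  # load tokenizer and model as in app.py

# Streamlit >= 1.18 (Feb 2023): st.cache is deprecated; st.cache_resource is
# the intended cache for models and other global, non-serialized resources.
@st.cache_resource
def model_init_current():
    ...  # same body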