wdndev commited on
Commit
dc16d0b
·
1 Parent(s): 996a81c
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import streamlit as st
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers.generation.utils import GenerationConfig
6
+
7
+
8
+ st.set_page_config(page_title="Tiny LLM 92M Demo")
9
+ st.title("Tiny LLM 92M Demo")
10
+
11
+ model_id = "wdndev/tiny_llm_sft_92m"
12
+
13
+ @st.cache_resource
14
+ def load_model_tokenizer():
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ model_id,
17
+ device_map="auto",
18
+ trust_remote_code=True
19
+ )
20
+ tokenizer = AutoTokenizer.from_pretrained(
21
+ model_id,
22
+ use_fast=False,
23
+ trust_remote_code=True
24
+ )
25
+ generation_config = GenerationConfig.from_pretrained(model_id)
26
+ return model, tokenizer, generation_config
27
+
28
+
29
+ def clear_chat_messages():
30
+ del st.session_state.messages
31
+
32
+
33
+ def init_chat_messages():
34
+ with st.chat_message("assistant", avatar='🤖'):
35
+ st.markdown("您好,我是由wdndev开发的个人助手,很高兴为您服务😄")
36
+
37
+ if "messages" in st.session_state:
38
+ for message in st.session_state.messages:
39
+ avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
40
+ with st.chat_message(message["role"], avatar=avatar):
41
+ st.markdown(message["content"])
42
+ else:
43
+ st.session_state.messages = []
44
+
45
+ return st.session_state.messages
46
+
47
+
48
+ max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 1024, 512, step=1)
49
+ top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
50
+ top_k = st.sidebar.slider("top_k", 0, 100, 0, step=1)
51
+ temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step=0.01)
52
+ do_sample = st.sidebar.checkbox("do_sample", value=True)
53
+
54
+ def main():
55
+ model, tokenizer, generation_config = load_model_tokenizer()
56
+ messages = init_chat_messages()
57
+
58
+ if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
59
+ with st.chat_message("user", avatar='🧑‍💻'):
60
+ st.markdown(prompt)
61
+ with st.chat_message("assistant", avatar='🤖'):
62
+ placeholder = st.empty()
63
+
64
+ generation_config.max_new_tokens = max_new_tokens
65
+ generation_config.top_p = top_p
66
+ generation_config.top_k = top_k
67
+ generation_config.temperature = temperature
68
+ generation_config.do_sample = do_sample
69
+ print("generation_config: ", generation_config)
70
+
71
+ sys_text = "你是由wdndev开发的个人助手。"
72
+ messages.append({"role": "user", "content": prompt})
73
+ user_text = prompt
74
+ input_txt = "\n".join(["<|system|>", sys_text.strip(),
75
+ "<|user|>", user_text.strip(),
76
+ "<|assistant|>"]).strip() + "\n"
77
+
78
+ model_inputs = tokenizer(input_txt, return_tensors="pt").to(model.device)
79
+ generated_ids = model.generate(model_inputs.input_ids, generation_config=generation_config)
80
+ generated_ids = [
81
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
82
+ ]
83
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
84
+ placeholder.markdown(response)
85
+
86
+ messages.append({"role": "assistant", "content": response})
87
+ print("messages: ", json.dumps(response, ensure_ascii=False), flush=True)
88
+
89
+ st.button("清空对话", on_click=clear_chat_messages)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()