Kit-Hung committed on
Commit b44b961 · 1 Parent(s): f94f54f
Files changed (4)
  1. README.md +5 -3
  2. app.py +295 -0
  3. download.py +9 -0
  4. requirements.txt +6 -0
README.md CHANGED
@@ -1,12 +1,14 @@
  ---
  title: Intern Assistant
- emoji: 🌖
- colorFrom: indigo
- colorTo: green
+ emoji: 📈
+ colorFrom: green
+ colorTo: blue
  sdk: streamlit
  sdk_version: 1.41.1
  app_file: app.py
  pinned: false
+ license: apache-2.0
+ short_description: intern assistant
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,295 @@
+ """This script refers to the dialogue example of streamlit, the interactive
2
+ generation code of chatglm2 and transformers.
3
+
4
+ We mainly modified part of the code logic to adapt to the
5
+ generation of our model.
6
+ Please refer to these links below for more information:
7
+ 1. streamlit chat example:
8
+ https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
9
+ 2. chatglm2:
10
+ https://github.com/THUDM/ChatGLM2-6B
11
+ 3. transformers:
12
+ https://github.com/huggingface/transformers
13
+ Please run with the command `streamlit run path/to/web_demo.py
14
+ --server.address=0.0.0.0 --server.port 7860`.
15
+ Using `python path/to/web_demo.py` may cause unknown problems.
16
+ """
+ # isort: skip_file
+ import copy
+ import warnings
+ from dataclasses import asdict, dataclass
+ from typing import Callable, List, Optional
+
+ import streamlit as st
+ import torch
+ from torch import nn
+ from transformers.generation.utils import (LogitsProcessorList,
+                                            StoppingCriteriaList)
+ from transformers.utils import logging
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
+ from download import download_assist_tuner
+
+ logger = logging.get_logger(__name__)
+ model_name_or_path = '/home/user/assistTuner'
+
+ download_assist_tuner(model_name_or_path)
+
+ @dataclass
+ class GenerationConfig:
+     # this config is used for chat to provide more diversity
+     max_length: int = 32768
+     top_p: float = 0.8
+     temperature: float = 0.8
+     do_sample: bool = True
+     repetition_penalty: float = 1.005
+
+
+ @torch.inference_mode()
+ def generate_interactive(
+     model,
+     tokenizer,
+     prompt,
+     generation_config: Optional[GenerationConfig] = None,
+     logits_processor: Optional[LogitsProcessorList] = None,
+     stopping_criteria: Optional[StoppingCriteriaList] = None,
+     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                 List[int]]] = None,
+     additional_eos_token_id: Optional[int] = None,
+     **kwargs,
+ ):
+     inputs = tokenizer([prompt], padding=True, return_tensors='pt')
+     input_length = len(inputs['input_ids'][0])
+     for k, v in inputs.items():
+         inputs[k] = v.cuda()
+     input_ids = inputs['input_ids']
+     _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+     if generation_config is None:
+         generation_config = model.generation_config
+     generation_config = copy.deepcopy(generation_config)
+     model_kwargs = generation_config.update(**kwargs)
+     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+         generation_config.bos_token_id,
+         generation_config.eos_token_id,
+     )
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     if additional_eos_token_id is not None:
+         eos_token_id.append(additional_eos_token_id)
+     has_default_max_length = kwargs.get(
+         'max_length') is None and generation_config.max_length is not None
+     if has_default_max_length and generation_config.max_new_tokens is None:
+         warnings.warn(
+             f"Using 'max_length''s default \
+                 ({repr(generation_config.max_length)}) \
+                 to control the generation length. "
+             'This behaviour is deprecated and will be removed from the \
+                 config in v5 of Transformers -- we'
+             ' recommend using `max_new_tokens` to control the maximum \
+                 length of the generation.',
+             UserWarning,
+         )
+     elif generation_config.max_new_tokens is not None:
+         generation_config.max_length = generation_config.max_new_tokens + \
+             input_ids_seq_length
+         if not has_default_max_length:
+             logger.warn(  # pylint: disable=W4902
+                 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
+                 f"and 'max_length'(={generation_config.max_length}) seem to "
+                 "have been set. 'max_new_tokens' will take precedence. "
+                 'Please refer to the documentation for more information. '
+                 '(https://huggingface.co/docs/transformers/main/'
+                 'en/main_classes/text_generation)',
+                 UserWarning,
+             )
+
+     if input_ids_seq_length >= generation_config.max_length:
+         input_ids_string = 'input_ids'
+         logger.warning(
+             f'Input length of {input_ids_string} is {input_ids_seq_length}, '
+             f"but 'max_length' is set to {generation_config.max_length}. "
+             'This can lead to unexpected behavior. You should consider'
+             " increasing 'max_new_tokens'.")
+
+     # 2. Set generation parameters if not already defined
+     logits_processor = logits_processor if logits_processor is not None \
+         else LogitsProcessorList()
+     stopping_criteria = stopping_criteria if stopping_criteria is not None \
+         else StoppingCriteriaList()
+
+     logits_processor = model._get_logits_processor(
+         generation_config=generation_config,
+         input_ids_seq_length=input_ids_seq_length,
+         encoder_input_ids=input_ids,
+         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+         logits_processor=logits_processor,
+     )
+
+     stopping_criteria = model._get_stopping_criteria(
+         generation_config=generation_config,
+         stopping_criteria=stopping_criteria)
+     logits_warper = model._get_logits_warper(generation_config)
+
+     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+     scores = None
+     while True:
+         model_inputs = model.prepare_inputs_for_generation(
+             input_ids, **model_kwargs)
+         # forward pass to get next token
+         outputs = model(
+             **model_inputs,
+             return_dict=True,
+             output_attentions=False,
+             output_hidden_states=False,
+         )
+
+         next_token_logits = outputs.logits[:, -1, :]
+
+         # pre-process distribution
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+         next_token_scores = logits_warper(input_ids, next_token_scores)
+
+         # sample
+         probs = nn.functional.softmax(next_token_scores, dim=-1)
+         if generation_config.do_sample:
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         else:
+             next_tokens = torch.argmax(probs, dim=-1)
+
+         # update generated ids, model inputs, and length for next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         model_kwargs = model._update_model_kwargs_for_generation(
+             outputs, model_kwargs, is_encoder_decoder=False)
+         unfinished_sequences = unfinished_sequences.mul(
+             (min(next_tokens != i for i in eos_token_id)).long())
+
+         output_token_ids = input_ids[0].cpu().tolist()
+         output_token_ids = output_token_ids[input_length:]
+         for each_eos_token_id in eos_token_id:
+             if output_token_ids[-1] == each_eos_token_id:
+                 output_token_ids = output_token_ids[:-1]
+         response = tokenizer.decode(output_token_ids)
+
+         yield response
+         # stop when each sentence is finished
+         # or if we exceed the maximum length
+         if unfinished_sequences.max() == 0 or stopping_criteria(
+                 input_ids, scores):
+             break
+
+
+ def on_btn_click():
+     del st.session_state.messages
+
+
+ @st.cache_resource
+ def load_model():
+     model = (AutoModelForCausalLM.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True).to(torch.bfloat16).cuda())
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                               trust_remote_code=True)
+     return model, tokenizer
+
+
+ def prepare_generation_config():
+     with st.sidebar:
+         max_length = st.slider('Max Length',
+                                min_value=8,
+                                max_value=32768,
+                                value=32768)
+         top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
+         temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+         st.button('Clear Chat History', on_click=on_btn_click)
+
+     generation_config = GenerationConfig(max_length=max_length,
+                                          top_p=top_p,
+                                          temperature=temperature)
+
+     return generation_config
+
+
+ user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
+ robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
+ cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
+     <|im_start|>assistant\n'
+
+
+ def combine_history(prompt):
+     messages = st.session_state.messages
+     meta_instruction = ('You are a helpful, honest, '
+                         'and harmless AI assistant.')
+     total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+     for message in messages:
+         cur_content = message['content']
+         if message['role'] == 'user':
+             cur_prompt = user_prompt.format(user=cur_content)
+         elif message['role'] == 'robot':
+             cur_prompt = robot_prompt.format(robot=cur_content)
+         else:
+             raise RuntimeError
+         total_prompt += cur_prompt
+     total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+     return total_prompt
+
+
+ def main():
+     st.title('internlm2_5-7b-chat-assistant')
+
+     # torch.cuda.empty_cache()
+     print('load model begin.')
+     model, tokenizer = load_model()
+     print('load model end.')
+
+     generation_config = prepare_generation_config()
+
+     # Initialize chat history
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message['role'], avatar=message.get('avatar')):
+             st.markdown(message['content'])
+
+     # Accept user input
+     if prompt := st.chat_input('What is up?'):
+         # Display user message in chat message container
+
+         with st.chat_message('user', avatar='user'):
+
+             st.markdown(prompt)
+         real_prompt = combine_history(prompt)
+         # Add user message to chat history
+         st.session_state.messages.append({
+             'role': 'user',
+             'content': prompt,
+             'avatar': 'user'
+         })
+
+         with st.chat_message('robot', avatar='assistant'):
+
+             message_placeholder = st.empty()
+             for cur_response in generate_interactive(
+                     model=model,
+                     tokenizer=tokenizer,
+                     prompt=real_prompt,
+                     additional_eos_token_id=92542,
+                     device='cuda:0',
+                     **asdict(generation_config),
+             ):
+                 # Display robot response in chat message container
+                 message_placeholder.markdown(cur_response + '▌')
+             message_placeholder.markdown(cur_response)
+         # Add robot response to chat history
+         st.session_state.messages.append({
+             'role': 'robot',
+             'content': cur_response,  # pylint: disable=undefined-loop-variable
+             'avatar': 'assistant',
+         })
+         torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     main()
+
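For reference, with an empty chat history combine_history assembles a ChatML-style prompt roughly like the following (the question text is only an illustration; the line continuation inside cur_query_prompt carries its leading spaces into the last line):

    <s><|im_start|>system
    You are a helpful, honest, and harmless AI assistant.<|im_end|>
    <|im_start|>user
    Who are you?<|im_end|>
        <|im_start|>assistant

The additional_eos_token_id=92542 passed to generate_interactive above is treated as an extra end-of-sequence token, so the streaming loop stops once the assistant turn is closed.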
download.py ADDED
@@ -0,0 +1,9 @@
+ import os
+
+ def download_assist_tuner(model_name_or_path):
+     # Set the environment variable (use the hf-mirror endpoint)
+     os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+
+     # Download the model
+     os.system(
+         f'mkdir -p {model_name_or_path} && huggingface-cli download --resume-download KitHung/internlm2-chat-1_8b_assistant --local-dir {model_name_or_path}')
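As a side note, the same download could be done in-process instead of shelling out to huggingface-cli. A minimal sketch (not part of this commit), assuming the huggingface_hub package that transformers already depends on:

import os

# HF_ENDPOINT is read when huggingface_hub is imported, so set it first
# (this mirrors the mirror-endpoint setting in the shell version above).
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')

from huggingface_hub import snapshot_download  # noqa: E402


def download_assist_tuner(model_name_or_path):
    # Download the full model repository into the target directory.
    snapshot_download(repo_id='KitHung/internlm2-chat-1_8b_assistant',
                      local_dir=model_name_or_path)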
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ streamlit==1.40.1
+ transformers==4.39.0
+ torch==2.4.1
+ torchvision==0.19.1
+ torchaudio==2.4.1
+ einops==0.8.0
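For local testing outside the Space, the usual steps would presumably be `pip install -r requirements.txt` followed by the command from the app.py docstring, `streamlit run app.py --server.address=0.0.0.0 --server.port 7860`; a CUDA-capable GPU is required, since load_model() moves the weights onto the GPU with .cuda().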