xiaoxishui committed
Commit 1559fe0 · 1 Parent(s): 728a621

Update transformers to 4.39.3 and optimize model loading

Files changed (2):
  1. app.py +9 -7
  2. requirements.txt +1 -1
app.py CHANGED
@@ -36,6 +36,7 @@ model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
+
 @dataclass
 class GenerationConfig:
     # this config is used for chat to provide more diversity
@@ -187,7 +188,10 @@ def on_btn_click():
 def load_model():
     model = (AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
-        trust_remote_code=True).to(torch.bfloat16).cuda())
+        trust_remote_code=True,
+        use_cache=False,  # disable the KV cache
+        torch_dtype=torch.bfloat16,
+        device_map="auto")).cuda()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                               trust_remote_code=True)
     return model, tokenizer
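Note on the new loading path: with device_map="auto", accelerate already places the weights across the available devices, so the trailing .cuda() is redundant on a single-GPU machine and can conflict with a sharded placement; use_cache=False disables the KV cache, which saves memory at the cost of slower autoregressive generation. A minimal standalone sketch of this loading step, assuming a single visible GPU and omitting the redundant .cuda():

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "xiaoxishui/internlm2_5-7b-chat"

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    use_cache=False,             # disable the KV cache (memory over speed)
    torch_dtype=torch.bfloat16,
    device_map="auto",           # accelerate handles placement; no .cuda() call
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                          trust_remote_code=True)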
@@ -210,17 +214,16 @@ def prepare_generation_config():
     return generation_config
 
 
-user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
-robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
-cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
-    <|im_start|>assistant\n'
+user_prompt = '👥\n{user}\n'
+robot_prompt = '🤖\n{robot}\n'
+cur_query_prompt = '👥\n{user}\n'
 
 
 def combine_history(prompt):
     messages = st.session_state.messages
     meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
-    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+    total_prompt = f'<s>🤖\n{meta_instruction}\n'
     for message in messages:
         cur_content = message['content']
         if message['role'] == 'user':
@@ -293,4 +296,3 @@ def main():
 
 if __name__ == '__main__':
     main()
-
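The new templates replace InternLM2's ChatML tags (<|im_start|>, <|im_end|>) with plain emoji markers, and the new cur_query_prompt no longer ends with an assistant cue for the reply, so the model receives no explicit turn marker for its answer. A sketch of how the history is presumably assembled under the new templates; only the hunk's context lines are shown above, so the loop body and the messages parameter (the app reads st.session_state.messages) are assumptions:

user_prompt = '👥\n{user}\n'
robot_prompt = '🤖\n{robot}\n'
cur_query_prompt = '👥\n{user}\n'

def combine_history(prompt, messages):
    # messages: list of {'role': ..., 'content': ...} dicts (assumed shape)
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s>🤖\n{meta_instruction}\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            total_prompt += user_prompt.format(user=cur_content)
        else:  # assumed: anything else is an assistant turn
            total_prompt += robot_prompt.format(robot=cur_content)
    total_prompt += cur_query_prompt.format(user=prompt)
    return total_prompt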
 
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 streamlit>=1.8.0
-transformers==4.36.0
+transformers==4.39.3
 torch>=2.0.0
 accelerate>=0.20.0
 sentencepiece
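Because the app pins transformers exactly while leaving the other dependencies as ranges, a startup check along these lines (a hypothetical addition, not part of this commit) can fail fast in a mismatched environment:

from importlib.metadata import version  # stdlib since Python 3.8

expected = "4.39.3"   # the pin from requirements.txt
installed = version("transformers")
if installed != expected:
    raise RuntimeError(
        f"app.py was tested against transformers=={expected}, "
        f"found {installed}")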