kevinwang676 committed
Commit 056a2f6 · 1 Parent(s): c467392

Update app.py

Files changed (1)
  1. app.py +18 -2
app.py CHANGED
@@ -4,14 +4,30 @@ import mdtex2html
 import torch
 import os
 
-#CHECKPOINT_PATH=f'output_lh/checkpoint-600'
+
+CHECKPOINT_PATH=f'output_lh_v2/checkpoint-700'
 tokenizer = AutoTokenizer.from_pretrained("chatglm3-6b", trust_remote_code=True)
+config = AutoConfig.from_pretrained("chatglm3-6b", trust_remote_code=True, pre_seq_len=128)
+model = AutoModel.from_pretrained("chatglm3-6b", config=config, trust_remote_code=True)
+prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"), map_location=torch.device('cpu'))
+
+#prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
+
+new_prefix_state_dict = {}
+for k, v in prefix_state_dict.items():
+    if k.startswith("transformer.prefix_encoder."):
+        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
 
-model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, device='cuda')
+#model = model.half().cuda()
 
+model = model.float()
+
+model.transformer.prefix_encoder.float()
 model = model.eval()
 
 
+
 """Override Chatbot.postprocess"""
 
 
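For context, here is a minimal, hedged sketch of what app.py does after this commit: load the base ChatGLM3-6B weights, inject the P-Tuning v2 prefix-encoder weights from the new checkpoint, and run inference in float32 on CPU. The local "chatglm3-6b" model directory and the checkpoint path are taken from the diff; the imports (including AutoConfig, which the new code needs) sit above the hunk and are assumed here, and the closing chat() call is an assumption based on ChatGLM3's standard remote-code API, not part of this commit.

# Sketch only, not the committed app.py. Assumes a local "chatglm3-6b"
# model directory and a P-Tuning v2 checkpoint at the path used in the diff.
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

CHECKPOINT_PATH = "output_lh_v2/checkpoint-700"  # path from the diff

tokenizer = AutoTokenizer.from_pretrained("chatglm3-6b", trust_remote_code=True)

# pre_seq_len must match the value used when the prefix was trained.
config = AutoConfig.from_pretrained("chatglm3-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("chatglm3-6b", config=config, trust_remote_code=True)

# The fine-tuned checkpoint stores only the prefix encoder; strip the
# "transformer.prefix_encoder." key prefix before loading it into the model.
state = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"),
                   map_location=torch.device("cpu"))
prefix_weights = {
    k[len("transformer.prefix_encoder."):]: v
    for k, v in state.items()
    if k.startswith("transformer.prefix_encoder.")
}
model.transformer.prefix_encoder.load_state_dict(prefix_weights)

# float32 on CPU; the half-precision GPU path (model.half().cuda()) is
# left commented out in the commit.
model = model.float()
model.transformer.prefix_encoder.float()
model = model.eval()

# ChatGLM3's remote modeling code exposes a chat() helper (assumed here).
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)

In short, the commit replaces direct loading of THUDM/chatglm3-6b onto CUDA with loading a local copy plus the checkpoint-700 prefix weights, and moves inference to CPU/float32.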