yuhaofeng-shiba committed on
Commit
05abaae
·
1 Parent(s): 1ae602b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -6,6 +6,7 @@ from utils import load_hyperparam, load_model
6
  from models.tokenize import Tokenizer
7
  from models.llama import *
8
  from generate import LmGeneration
 
9
 
10
  import os
11
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
@@ -17,15 +18,16 @@ def init_args():
17
  global args
18
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
  args = parser.parse_args()
20
- args.load_model_path = './model_file/chatllama_7b.bin'
21
- args.config_path = './config/llama_7b.json'
 
22
  #args.load_model_path = './model_file/chatflow_13b.bin'
23
- #args.config_path = './config/llama_13b_config.json'
24
  args.spm_model_path = './model_file/tokenizer.model'
25
  args.batch_size = 1
26
  args.seq_length = 1024
27
  args.world_size = 1
28
- args.use_int8 = False
29
  args.top_p = 0
30
  args.repetition_penalty_range = 1024
31
  args.repetition_penalty_slope = 0
@@ -42,6 +44,7 @@ def init_model():
42
  torch.set_default_tensor_type(torch.HalfTensor)
43
  model = LLaMa(args)
44
  torch.set_default_tensor_type(torch.FloatTensor)
 
45
  model = load_model(model, args.load_model_path)
46
  model.eval()
47
 
 
6
  from models.tokenize import Tokenizer
7
  from models.llama import *
8
  from generate import LmGeneration
9
+ from huggingface_hub import hf_hub_download
10
 
11
  import os
12
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
18
  global args
19
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
  args = parser.parse_args()
21
+ args.load_model_path = 'Linly-AI/ChatFlow-13B'
22
+ # args.load_model_path = './model_file/chatllama_7b.bin'
23
+ # args.config_path = './config/llama_7b.json'
24
  #args.load_model_path = './model_file/chatflow_13b.bin'
25
+ args.config_path = './config/llama_13b_config.json'
26
  args.spm_model_path = './model_file/tokenizer.model'
27
  args.batch_size = 1
28
  args.seq_length = 1024
29
  args.world_size = 1
30
+ args.use_int8 = True
31
  args.top_p = 0
32
  args.repetition_penalty_range = 1024
33
  args.repetition_penalty_slope = 0
 
44
  torch.set_default_tensor_type(torch.HalfTensor)
45
  model = LLaMa(args)
46
  torch.set_default_tensor_type(torch.FloatTensor)
47
+ args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
48
  model = load_model(model, args.load_model_path)
49
  model.eval()
50