wmpscc committed on
Commit
b635f37
·
1 Parent(s): c8d71ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -41,12 +41,13 @@ def init_args():
41
  args = load_hyperparam(args)
42
 
43
  # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
44
- args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", trust_remote_code=True)
45
  args.vocab_size = args.tokenizer.sp_model.vocab_size()
46
 
47
 
48
  def init_model():
49
  global lm_generation
 
50
  # torch.set_default_tensor_type(torch.HalfTensor)
51
  # model = LLaMa(args)
52
  # torch.set_default_tensor_type(torch.FloatTensor)
@@ -64,11 +65,12 @@ def init_model():
64
 
65
 
66
  def chat(prompt, top_k, temperature):
67
- args.top_k = int(top_k)
68
- args.temperature = temperature
69
- response = lm_generation.generate(args, [prompt])
70
- print('log:', response[0])
71
- return response[0]
 
72
 
73
 
74
  if __name__ == '__main__':
 
41
  args = load_hyperparam(args)
42
 
43
  # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
44
+ args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
45
  args.vocab_size = args.tokenizer.sp_model.vocab_size()
46
 
47
 
48
  def init_model():
49
  global lm_generation
50
+ global model
51
  # torch.set_default_tensor_type(torch.HalfTensor)
52
  # model = LLaMa(args)
53
  # torch.set_default_tensor_type(torch.FloatTensor)
 
65
 
66
 
67
  def chat(prompt, top_k, temperature):
68
+ # args.top_k = int(top_k)
69
+ # args.temperature = temperature
70
+ # response = lm_generation.generate(args, [prompt])
71
+ response = model.chat(args.tokenizer, [prompt])
72
+ print('log:', response)
73
+ return response
74
 
75
 
76
  if __name__ == '__main__':