wmpscc committed on
Commit 982c554 · 1 Parent(s): b635f37

Update app.py

Files changed (1)
  1. app.py +12 -63
app.py CHANGED
@@ -1,81 +1,30 @@
- import torch
-
- import gradio as gr
- import argparse
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- # from transformers import LlamaForCausalLM, LlamaForTokenizer
-
- from utils import load_hyperparam, load_model
- from models.tokenize import Tokenizer
- from models.llama import *
- from generate import LmGeneration
- from huggingface_hub import hf_hub_download
-
  import os
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

- args = None
- lm_generation = None
-
- def init_args():
-     global args
-     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-     args = parser.parse_args()
-     args.load_model_path = 'Linly-AI/ChatFlow-13B'
-     #args.load_model_path = 'Linly-AI/ChatFlow-7B'
-     # args.load_model_path = './model_file/chatllama_7b.bin'
-     #args.config_path = './config/llama_7b.json'
-     #args.load_model_path = './model_file/chatflow_13b.bin'
-     args.config_path = './config/llama_13b_config.json'
-     args.spm_model_path = './model_file/tokenizer.model'
-     args.batch_size = 1
-     args.seq_length = 1024
-     args.world_size = 1
-     args.use_int8 = True
-     args.top_p = 0
-     args.repetition_penalty_range = 1024
-     args.repetition_penalty_slope = 0
-     args.repetition_penalty = 1.15
-
-     args = load_hyperparam(args)
-
-     # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
-     args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
-     args.vocab_size = args.tokenizer.sp_model.vocab_size()


  def init_model():
-     global lm_generation
-     global model
-     # torch.set_default_tensor_type(torch.HalfTensor)
-     # model = LLaMa(args)
-     # torch.set_default_tensor_type(torch.FloatTensor)
-     # # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-     # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-     # model = load_model(model, args.load_model_path)
-     # model.eval()
-
-     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     # model.to(device)
      model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
-     print(model)
-     print(torch.cuda.max_memory_allocated() / 1024 ** 3)
-     lm_generation = LmGeneration(model, args.tokenizer)


  def chat(prompt, top_k, temperature):
-     # args.top_k = int(top_k)
-     # args.temperature = temperature
-     # response = lm_generation.generate(args, [prompt])
-     response = model.chat(args.tokenizer, [prompt])
      print('log:', response)
      return response


  if __name__ == '__main__':
-     init_args()
-     init_model()
      demo = gr.Interface(
          fn=chat,
          inputs=["text", gr.Slider(1, 60, value=10, step=1), gr.Slider(0.1, 2.0, value=1.0, step=0.1)],

  import os
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

+ import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer


  def init_model():
      model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
+     return model, tokenizer
+


  def chat(prompt, top_k, temperature):
+     prompt = f"### Instruction:{prompt.strip()} ### Response:"
+     inputs = tokenizer(prompt, return_tensors="pt")
+     generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample=True, top_k=top_k, top_p=0, temperature=temperature, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1, pad_token_id=0)
+     response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     response = response.lstrip(prompt)
      print('log:', response)
      return response


  if __name__ == '__main__':
+     model, tokenizer = init_model()
      demo = gr.Interface(
          fn=chat,
          inputs=["text", gr.Slider(1, 60, value=10, step=1), gr.Slider(0.1, 2.0, value=1.0, step=0.1)],
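
One detail in the new chat() worth flagging: str.lstrip(prompt) strips any leading characters that occur anywhere in the prompt string rather than removing the prompt as a prefix, so it can also eat the first characters of the model's reply. Below is a minimal sketch of a prefix-safe alternative, assuming the same prompt/response strings as in the diff; the helper name strip_prompt_prefix is illustrative and not part of the commit.

```python
def strip_prompt_prefix(response: str, prompt: str) -> str:
    # str.lstrip(prompt) treats `prompt` as a character set to trim, which can
    # also remove legitimate leading characters of the answer. Slicing off the
    # exact prefix avoids that (str.removeprefix does the same on Python 3.9+).
    if response.startswith(prompt):
        return response[len(prompt):].lstrip()
    return response
```

Inside chat(), the line `response = response.lstrip(prompt)` would then read `response = strip_prompt_prefix(response, prompt)`.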