Update app.py
app.py CHANGED
@@ -1,81 +1,30 @@
-import torch
-
-import gradio as gr
-import argparse
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-# from transformers import LlamaForCausalLM, LlamaForTokenizer
-
-from utils import load_hyperparam, load_model
-from models.tokenize import Tokenizer
-from models.llama import *
-from generate import LmGeneration
-from huggingface_hub import hf_hub_download
-
 import os
 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
-
-
-
-def init_args():
-    global args
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    args = parser.parse_args()
-    args.load_model_path = 'Linly-AI/ChatFlow-13B'
-    #args.load_model_path = 'Linly-AI/ChatFlow-7B'
-    # args.load_model_path = './model_file/chatllama_7b.bin'
-    #args.config_path = './config/llama_7b.json'
-    #args.load_model_path = './model_file/chatflow_13b.bin'
-    args.config_path = './config/llama_13b_config.json'
-    args.spm_model_path = './model_file/tokenizer.model'
-    args.batch_size = 1
-    args.seq_length = 1024
-    args.world_size = 1
-    args.use_int8 = True
-    args.top_p = 0
-    args.repetition_penalty_range = 1024
-    args.repetition_penalty_slope = 0
-    args.repetition_penalty = 1.15
-
-    args = load_hyperparam(args)
-
-    # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
-    args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
-    args.vocab_size = args.tokenizer.sp_model.vocab_size()
+import torch
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 def init_model():
-    global lm_generation
-    global model
-    # torch.set_default_tensor_type(torch.HalfTensor)
-    # model = LLaMa(args)
-    # torch.set_default_tensor_type(torch.FloatTensor)
-    # # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-    # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-    # model = load_model(model, args.load_model_path)
-    # model.eval()
-
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # model.to(device)
     model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
-
-
-
+    tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
+    return model, tokenizer
+
 
 
 def chat(prompt, top_k, temperature):
-
-
-
-    response =
+    prompt = f"### Instruction:{prompt.strip()} ### Response:"
+    inputs = tokenizer(prompt, return_tensors="pt")
+    generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample = True, top_k=top_k, top_p = 0, temperature=temperature, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1, pad_token_id=0)
+    response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    response = response.lstrip(prompt)
     print('log:', response)
     return response
 
 
 if __name__ == '__main__':
-
-    init_model()
+    model, tokenizer = init_model()
     demo = gr.Interface(
         fn=chat,
         inputs=["text", gr.Slider(1, 60, value=10, step=1), gr.Slider(0.1, 2.0, value=1.0, step=0.1)],
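
Two caveats in the new chat() body may explain runtime issues with this Space. First, with device_map="auto" the model weights are typically placed on the GPU while tokenizer(...) returns CPU tensors, so the input IDs generally need to be moved to the model's device before generate(). Second, str.lstrip(prompt) strips any leading characters that appear in the prompt string (a character set), not the prompt as a prefix, so it can also remove the start of the response. Below is a minimal sketch of a variant that addresses both points; it reuses the model and tokenizer names from the diff, drops the dubious top_p=0 setting, and the function name chat_fixed is illustrative, not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: same model/tokenizer as the committed app.py, with inputs moved to
# the model's device and the prompt removed as a string prefix rather than via lstrip().
model = AutoModelForCausalLM.from_pretrained(
    "Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="auto",
    torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    "Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)

def chat_fixed(prompt, top_k, temperature):
    # Same prompt template as the commit.
    prompt = f"### Instruction:{prompt.strip()} ### Response:"
    # Move input tensors to wherever the (possibly sharded) model lives.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generate_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=2048,
        do_sample=True,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=1.15,
        eos_token_id=2, bos_token_id=1, pad_token_id=0,
    )
    text = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True,
        clean_up_tokenization_spaces=False)[0]
    # Strip the prompt as a prefix, not as a character set.
    return text[len(prompt):] if text.startswith(prompt) else text

The function keeps the same (prompt, top_k, temperature) signature as the committed chat(), so it would slot into the same gr.Interface wiring, where top_k comes from the integer slider (1-60) and temperature from the float slider (0.1-2.0).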