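"""Gradio web demo for chatting with a TransGPT (LLaMA-based) model, optionally merged with a LoRA adapter.

Example launch (the script name and model path here are placeholders, not taken from the repo):
    python gradio_demo.py --model_type llama --base_model /path/to/merged-sft-model --gpus 0
"""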
import argparse
import os
import gradio as gr
import mdtex2html
from gradio.themes.utils import colors, fonts, sizes
import torch
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    BloomForCausalLM,
    BloomTokenizerFast,
    LlamaTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)
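
# Map --model_type to the (model, tokenizer) classes used to load it.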
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}
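

# Custom Gradio theme: blue/sky palette, Noto Sans text and IBM Plex Mono code fonts.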
class OpenGVLab(gr.themes.base.Base):
    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="*neutral_50",
        )


gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default="llama", type=str)
    parser.add_argument('--base_model', default=r"/data/wangpeng/JiaotongGPT-main/merged-sft-no-1ep", type=str)
    parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--gpus', default="0", type=str)
    parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    args = parser.parse_args()
    if args.only_cpu:
        args.gpus = ""
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
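
    # Patch gr.Chatbot so user and bot messages are rendered from Markdown/LaTeX via mdtex2html.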
    def postprocess(self, y):
        if y is None:
            return []
        for i, (message, response) in enumerate(y):
            y[i] = (
                None if message is None else mdtex2html.convert(message),
                None if response is None else mdtex2html.convert(response),
            )
        return y

    gr.Chatbot.postprocess = postprocess
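
    # Default sampling settings; predict() below builds its own GenerationConfig from the UI sliders.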
    generation_config = dict(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.1,
        max_new_tokens=400,
    )
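
    # Load weights in fp16; use the first visible GPU if available, otherwise fall back to the CPU.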
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
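
    # Load the tokenizer and base model; device_map='auto' places weights across the visible devices.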
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = model_class.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
        trust_remote_code=True,
    )
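
    # Optionally resize the embedding matrix so it matches the tokenizer vocabulary
    # (needed when extra tokens were added during fine-tuning).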
    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)

    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
        print("Loaded LoRA model")
    else:
        model = base_model
    if device == torch.device('cpu'):
        model.float()
    model.eval()
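
    # Callbacks for clearing the input box and the conversation state.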
    def reset_user_input():
        return gr.update(value='')

    def reset_state():
        return [], []
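
    # Alpaca-style prompt template with a transportation-domain system line.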
    def generate_prompt(instruction):
        return f"""You are TransGPT, a specialist in the field of transportation. Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response: """
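
    # Chat handler: folds prior turns into the instruction (truncating to the last
    # max_memory characters), generates a reply, and updates the chatbot/history state.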
    def predict(
        input,
        chatbot,
        history,
        max_new_tokens=128,
        top_p=0.75,
        temperature=0.1,
        top_k=40,
        num_beams=4,
        repetition_penalty=1.0,
        max_memory=256,
        **kwargs,
    ):
        now_input = input
        chatbot.append((input, ""))
        history = history or []
        if len(history) != 0:
            input = "".join(
                ["### Instruction:\n" + i[0] + "\n\n" + "### Response: " + i[1] + "\n\n" for i in history]
            ) + "### Instruction:\n" + input
            input = input[len("### Instruction:\n"):]
            if len(input) > max_memory:
                input = input[-max_memory:]
        prompt = generate_prompt(input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                max_new_tokens=max_new_tokens,
                repetition_penalty=float(repetition_penalty),
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s, skip_special_tokens=True)
        output = output.split("### Response:")[-1].strip()
        history.append((now_input, output))
        chatbot[-1] = (now_input, output)
        return chatbot, history
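
    # Build the web UI: chat pane, input box, submit/clear buttons, and generation sliders.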
    title = """<h1 align="center">Welcome to TransGPT!</h1>"""
    with gr.Blocks(title="DUOMO TransGPT!", theme=gvlabtheme,
                   css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
        gr.Markdown(title)
        # with gr.Blocks() as demo:
        #     gr.HTML("""<h1 align="center">TransGPT</h1>""")
        #     # gr.Markdown(
        #     #     "> To promote open research on large models for the medical industry, this project open-sources the TransGPT medical model")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(
                    0, 4096, value=128, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01,
                                  label="Top P", interactive=True)
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)

        history = gr.State([])  # (message, bot_message)
        submitBtn.click(predict, [user_input, chatbot, history, max_length, top_p, temperature], [chatbot, history],
                        show_progress=True)
        submitBtn.click(reset_user_input, [], [user_input])
        emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

    demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=8080)


if __name__ == '__main__':
    main()