test-chat / app.py
lliu01's picture
Change model
a8ba0ae verified
raw
history blame
3.68 kB
import argparse
import os
import spaces
import gradio as gr
import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Longest prompt, in tokens, that predict() passes to the model; a longer
# prompt keeps only its last MAX_LENGTH tokens.
MAX_LENGTH = 4096
# Default cap on the number of tokens generated for a single reply.
DEFAULT_MAX_NEW_TOKENS = 1024
def parse_args():
    """Parse the command-line options of the chat demo.

    Returns:
        An ``argparse.Namespace`` with ``base_model`` (str, ``None`` when the
        flag is omitted) and ``n_gpus`` (int, 1 when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in declaration order.
    options = (
        ("--base_model", {"type": str}),  # model path
        ("--n_gpus", {"type": int, "default": 1}),  # number of GPUs
    )
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream the model's reply to ``message`` for the Gradio chat UI.

    Reads the module-level ``model``, ``tokenizer`` and ``device`` that the
    script entry point creates before the interface is launched.

    Args:
        message: Latest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        system_prompt: Text sent as the leading ``system`` message.
        temperature: Sampling temperature. ``None`` or a value that is not
            strictly positive switches to greedy decoding, because sampling
            needs a strictly positive temperature.
        max_tokens: Upper bound on newly generated tokens; falls back to
            ``DEFAULT_MAX_NEW_TOKENS`` when ``None``.

    Yields:
        The reply text generated so far, growing with every streamed chunk.
    """
    messages = [{'role': 'system', 'content': system_prompt}]
    for human, assistant in history:
        messages.append({'role': 'user', 'content': human})
        messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})

    prompt = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)

    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask
    if input_ids.shape[1] > MAX_LENGTH:
        # Keep the most recent tokens. The mask must be sliced the same way,
        # otherwise its shape no longer matches input_ids.
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Honour the "Max Tokens" control instead of always using the default.
    if max_tokens is None:
        max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    else:
        max_new_tokens = max(1, int(max_tokens))

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=100278,  # <|im_end|>
    )
    if temperature is not None and float(temperature) > 0:
        generate_kwargs.update(do_sample=True, top_p=0.95, temperature=float(temperature))
    else:
        # The temperature slider can reach 0, which sampling rejects;
        # decode greedily instead.
        generate_kwargs["do_sample"] = False

    # generate() blocks until the reply is complete, so it runs in a worker
    # thread while this generator relays the streamed text to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    worker.join()
if __name__ == "__main__":
    # Script entry point: load the tokenizer and model once, expose them as
    # module-level globals for predict(), then start the Gradio chat UI.
    args = parse_args()
    # --base_model overrides the checkpoint; without it the demo loads the
    # same checkpoint it always has.
    model_name = args.base_model or "lliu01/fortios_cli"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    gr.ChatInterface(
        predict,
        title="FortiOS CLI Chat - Demo",
        description="FortiOS CLI Chat",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() after (message, history), in this order:
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("FortiOS firewall policy configuration.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["How can you move a policy by policy ID?"],
            ["What is the command to enable security profiles in a firewall policy?"],
            ["How do you configure a service group in the GUI?"],
            ["How can you configure the firewall policy change summary in the CLI?"],
            ["How do you disable hardware acceleration for an IPv4 firewall policy in the CLI?"],
            ["How can you enable WAN optimization in a firewall policy using the CLI?"],
            ["What are services in FortiOS and how are they used in firewall policies?"],
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()