File size: 3,677 Bytes
829da7c
 
31a1ff8
e2b5fc2
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
54fe16b
829da7c
 
 
 
 
54fe16b
31a1ff8
54fe16b
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b40bf
54fe16b
 
829da7c
a8ba0ae
 
ca3ac1a
a8ba0ae
ca3ac1a
 
 
829da7c
 
54fe16b
 
a8ba0ae
 
54fe16b
829da7c
54fe16b
 
 
 
 
a8ba0ae
829da7c
54fe16b
 
994685c
a8ba0ae
 
 
 
 
 
 
994685c
54fe16b
53b40bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import os
import spaces

import gradio as gr

import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Maximum prompt length in tokens: predict() keeps only the last MAX_LENGTH
# tokens of an over-long prompt (older conversation turns are dropped first).
MAX_LENGTH: int = 4096
# Generation length cap (passed as ``max_new_tokens``) used by predict().
DEFAULT_MAX_NEW_TOKENS: int = 1024


def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the function be called from tests or other code.

    Returns:
        argparse.Namespace with ``base_model`` (model repo id or local path)
        and ``n_gpus`` (int).
    """
    parser = argparse.ArgumentParser(description="FortiOS CLI chat demo")
    parser.add_argument(
        "--base_model",
        type=str,
        # Previously defaulted to None, which left the flag unusable; default
        # to the repository the demo is built around.
        default="lliu01/fortios_cli",
        help="Model repo id or local path (default: %(default)s).",
    )
    parser.add_argument(
        "--n_gpus",
        type=int,
        default=1,
        help="Number of GPUs (default: %(default)s).",
    )
    return parser.parse_args(argv)

def _build_messages(message, history, system_prompt):
    """Turn one chat turn plus its history into a chat-template message list.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs; ``None`` is
            treated as an empty history.
        system_prompt: Content of the leading system message.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts: the system message,
        the alternating history, then the new user message.
    """
    messages = [{'role': 'system', 'content': system_prompt}]
    for human, assistant in history or ():
        messages.append({'role': 'user', 'content': human})
        messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})
    return messages


def _generation_options(temperature, max_tokens):
    """Build the decoding/length keyword arguments for ``model.generate``.

    Args:
        temperature: Sampling temperature from the UI. The slider allows 0,
            but sampling with a non-positive temperature is rejected by
            transformers, so ``None`` or a value <= 0 selects greedy decoding.
        max_tokens: Maximum number of new tokens. Cast to ``int`` in case the
            UI slider delivers a float; ``None`` falls back to
            DEFAULT_MAX_NEW_TOKENS.

    Returns:
        dict with ``do_sample``, ``max_new_tokens`` and, when sampling,
        ``top_p`` and ``temperature``.
    """
    max_new_tokens = DEFAULT_MAX_NEW_TOKENS if max_tokens is None else int(max_tokens)
    if temperature is None or temperature <= 0:
        return {"do_sample": False, "max_new_tokens": max_new_tokens}
    return {
        "do_sample": True,
        "top_p": 0.95,
        "temperature": float(temperature),
        "max_new_tokens": max_new_tokens,
    }


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream the model's reply to one chat turn (Gradio ChatInterface callback).

    Reads the module-level ``tokenizer``, ``model`` and ``device`` that the
    ``__main__`` section assigns before the UI is launched.

    Args:
        message: New user message.
        history: Previous turns as ``(user, assistant)`` pairs.
        system_prompt: Text placed in the leading system message.
        temperature: Sampling temperature; <= 0 selects greedy decoding.
        max_tokens: Maximum number of new tokens to generate (previously
            ignored in favour of DEFAULT_MAX_NEW_TOKENS).

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.
    """
    messages = _build_messages(message, history, system_prompt)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # No tokenizer-side truncation/padding: truncation cuts from the right by
    # default (dropping the generation prompt), and padding a single sequence
    # is a no-op that fails when the tokenizer has no pad token.
    enc = tokenizer(prompt, return_tensors="pt")
    # Keep only the most recent MAX_LENGTH tokens; slice the attention mask
    # too so both tensors keep matching shapes.
    input_ids = enc.input_ids[:, -MAX_LENGTH:]
    attention_mask = enc.attention_mask[:, -MAX_LENGTH:]

    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        use_cache=True,
        eos_token_id=100278,  # <|im_end|>
        **_generation_options(temperature, max_tokens),
    )

    # generate() blocks, so run it in a worker thread and relay text from the
    # streamer as it is produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    chunks = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)
    # The streamer is exhausted once generation finishes; reap the thread.
    worker.join()



if __name__ == "__main__":
    args = parse_args()
    # Honour --base_model; fall back to the repo the demo was built for when
    # the flag is absent (it may parse to None).
    model_id = args.base_model or "lliu01/fortios_cli"

    # tokenizer, model and device are module-level globals read by predict().
    # The tokenizer was previously loaded twice; once is enough.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    gr.ChatInterface(
        predict,
        title="FortiOS CLI Chat - Demo",
        description="FortiOS CLI Chat",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() after (message, history), in this order:
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("FortiOS firewall policy configuration.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["How can you move a policy by policy ID?"],
            ["What is the command to enable security profiles in a firewall policy?"],
            ["How do you configure a service group in the GUI?"],
            ["How can you configure the firewall policy change summary in the CLI?"],
            ["How do you disable hardware acceleration for an IPv4 firewall policy in the CLI?"],
            ["How can you enable WAN optimization in a firewall policy using the CLI?"],
            ["What are services in FortiOS and how are they used in firewall policies?"],
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()