File size: 3,063 Bytes
113dbd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import time
import argparse
from vllm import LLM, SamplingParams


def parse_args(argv=None):
    """Parse command-line options for the WizardLM Gradio playground.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``model`` (model path, str) and ``n_gpu``
        (int, used as vLLM's ``tensor_parallel_size``).
    """
    parser = argparse.ArgumentParser(
        description="Gradio chat playground for WizardLM served with vLLM."
    )
    # Required: without it ``model`` would be None and ``LLM(model=None)``
    # would presumably fail later with a far less helpful error.
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument(
        "--n_gpu",
        type=int,
        default=1,
        help="number of GPUs for tensor parallelism (default: 1)",
    )
    return parser.parse_args(argv)

def echo(message, history, system_prompt, temperature, max_tokens):
    """Debug stand-in for ``predict`` that streams back its own inputs.

    Builds a summary of the arguments and yields successively longer prefixes
    of it, one character per step with a 0.05 s pause, stopping after at most
    ``int(max_tokens)`` characters. ``history`` is accepted for interface
    compatibility with ``gr.ChatInterface`` but is not used.
    """
    summary = f"System prompt: {system_prompt}\n Message: {message}. \n Temperature: {temperature}. \n Max Tokens: {max_tokens}."
    limit = min(len(summary), int(max_tokens))
    for end in range(1, limit + 1):
        time.sleep(0.05)  # artificial delay to mimic token streaming
        yield summary[:end]

def _build_prompt(message, history, system_prompt):
    """Render the conversation into a single USER:/ASSISTANT: prompt string.

    A non-blank ``system_prompt`` is used as the preamble; otherwise the
    default chat preamble is used. Each completed ``(user, assistant)`` turn
    in ``history`` is closed with ``</s>``. The prompt ends with an open
    ``ASSISTANT:`` slot for the model to complete.
    """
    default_preamble = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    preamble = (system_prompt or "").strip() or default_preamble
    parts = [preamble + " "]
    for human, assistant in history:
        # Guard against None entries (presumably stored when an earlier reply
        # was empty); ``str + None`` would otherwise raise TypeError.
        parts.append("USER: " + (human or "") + " ASSISTANT: " + (assistant or "") + "</s>")
    parts.append("USER: " + message + " ASSISTANT:")
    return "".join(parts)


def predict(message, history, system_prompt, temperature, max_tokens):
    """Generate a reply with the module-level vLLM ``llm`` and stream it to Gradio.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` pairs supplied by ``gr.ChatInterface``.
        system_prompt: Preamble for the prompt. A blank value falls back to the
            default chat preamble.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate. It is cast to ``int``
            because the slider value may be a float (``echo`` casts it too).

    Yields:
        Successively longer prefixes of the completion, one character at a
        time, to simulate streaming. The whole completion is generated before
        the first yield. Yields ``""`` once when the completion is empty, so
        the chat always receives a string reply.
    """
    prompt = _build_prompt(message, history, system_prompt)
    # NOTE(review): bare words such as "Question" and "Response" stop generation
    # on any occurrence, which can truncate legitimate answers; kept unchanged.
    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(
        temperature=float(temperature),
        top_p=1,
        max_tokens=int(max_tokens),
        stop=stop_tokens,
    )
    # Exactly one prompt is submitted, so take its single result.
    completion = llm.generate([prompt], sampling_params)[0]
    generated_text = completion.outputs[0].text
    if not generated_text:
        yield ""
        return
    for end in range(1, len(generated_text) + 1):
        yield generated_text[:end]


"""
- Setup environment:
```bash
conda create -n wizardweb python=3.8 -y
conda activate wizardweb
pip install vllm
pip install transformers==4.31.0
pip install --upgrade gradio
pip install jsonlines
pip install ray==2.5.1
```
- Launch the demo:
```bash
python gradio_wizardlm.py --model /path/to/model --n_gpu 1
python gradio_wizardlm.py --model /workspaceblobstore/caxu/trained_models/13Bv2_v14continue_2048_e3_2e_5/checkpoint-850 --n_gpu 1
```

"""
if __name__ == "__main__":
    import os

    args = parse_args()
    # ``predict`` reads this module-level ``llm``. It is created only when the
    # file runs as a script.
    llm = LLM(model=args.model, tensor_parallel_size=args.n_gpu)

    # The bind address used to be hard-coded to a single internal host, which
    # only works on that machine. Allow overriding it via Gradio's standard
    # environment variables, keeping the previous values as defaults
    # (e.g. GRADIO_SERVER_NAME=0.0.0.0 to listen on all interfaces).
    server_name = os.environ.get(
        "GRADIO_SERVER_NAME", "phlrr2019.guest.corp.microsoft.com"
    )
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    # NOTE(review): retry_btn / undo_btn / clear_btn and
    # additional_inputs_accordion_name come from the older gr.ChatInterface
    # API. The setup notes install Gradio with --upgrade, so confirm these
    # keyword arguments exist in the installed version.
    gr.ChatInterface(
        predict,
        title="LLM playground - WizardLM",
        description="This is a LLM playground for WizardLM.",
        theme="soft",
        # examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"],
        # cache_examples=False,
        chatbot=gr.Chatbot(height=300, label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs=[
            # Default matches the chat preamble used in predict's prompt.
            gr.Textbox(
                "A chat between a curious user and an artificial intelligence assistant. "
                "The assistant gives helpful, detailed, and polite answers to the user's questions.",
                label="System Prompt",
            ),
            gr.Slider(0, 1, 0.9, label="Temperature"),
            gr.Slider(10, 1000, 800, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch(share=False, server_name=server_name, server_port=server_port)