File size: 1,708 Bytes
b87a3ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import gradio as gr

if TYPE_CHECKING:
    from gradio.blocks import Block
    from gradio.components import Component
    from llmtuner.webui.chat import WebChatModel


def create_chat_box(
    chat_model: "WebChatModel",
    visible: bool = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
    r"""
    Build the chat panel of the web UI.

    Creates a (possibly hidden) box containing the chatbot display, a system
    prompt field, a query field, generation-parameter sliders, and the
    submit/clear buttons, then wires the button events to ``chat_model``.

    Args:
        chat_model: the chat backend; supplies ``predict`` and the default
            generating arguments used to initialize the sliders.
        visible: whether the chat box is shown initially. Defaults to ``False``
            so the caller can reveal it once a model is loaded.

    Returns:
        A 4-tuple of:
            - the outer ``gr.Box`` (so callers can toggle its visibility),
            - the ``gr.Chatbot`` component,
            - the ``gr.State`` holding the conversation history,
            - a dict mapping element names to the remaining components
              (for localization / bulk updates elsewhere in the UI).
    """
    with gr.Box(visible=visible) as chat_box:
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                system = gr.Textbox(show_label=False)
                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")

            with gr.Column(scale=1):
                clear_btn = gr.Button()
                # Slider defaults mirror the model's current generating args
                # so the UI starts in sync with the backend configuration.
                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

    # Per-session conversation history (list of message pairs), kept in
    # gr.State so it is not shared across concurrent users.
    history = gr.State([])

    # Run prediction, then clear the query box only after predict finishes
    # (the .then() chain guarantees ordering).
    submit_btn.click(
        chat_model.predict,
        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
        [chatbot, history],
        show_progress=True
    ).then(
        lambda: gr.update(value=""), outputs=[query]
    )

    # Reset both the visible chat log and the internal history together.
    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

    return chat_box, chatbot, history, dict(
        system=system,
        query=query,
        submit_btn=submit_btn,
        clear_btn=clear_btn,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature
    )