from huggingface_hub import InferenceClient
import gradio as gr

# Inference client pointed at the hosted Mixtral instruct model.
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)


def format_prompt(message, history):
    """Build a Mixtral-instruct prompt from the chat history and the new message."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
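
# Illustrative sketch (not part of the original app): with one prior turn,
# format_prompt produces the Mixtral instruct layout, e.g.
#   format_prompt("How?", [("Hi", "Hello!")])
#   -> "<s>[INST] Hi [/INST] Hello!</s> [INST] How? [/INST]"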

def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    # Clamp temperature to a small positive floor; the backend requires
    # temperature > 0 when sampling.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Prepend the system prompt to the user message, then stream tokens from
    # the endpoint, yielding the growing answer so the UI updates incrementally.
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output
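
# Illustrative only: generate() is a generator, so outside of Gradio the
# stream could be drained directly (hypothetical call, not run by the app):
#   for partial in generate("Hello", [], "You are a helpful assistant"):
#       print(partial)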


additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.2,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more varied results.",
    ),
    gr.Slider(
        label="Max new tokens",
        value=16512,
        minimum=0,
        maximum=32768,
        step=64,
        interactive=True,
        info="Maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse output",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="How strongly repeated tokens are penalized",
    )
]

examples = [
    ["", "Always respond fully in Russian", 0.2, 16512, 0.90, 1.2],
]

# Wire the streaming generator and its extra controls into the chat UI.
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mix-OpenAI-Chat",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)