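"""Gradio chat demo that streams completions from Mixtral-8x7B-Instruct through
the Hugging Face InferenceClient. Logging verbosity is controlled by the
LOG_LEVEL environment variable (default: WARNING)."""
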
import os
import logging

from huggingface_hub import InferenceClient
import gradio as gr

log_level = os.environ.get("LOG_LEVEL", "WARNING")
logging.basicConfig(encoding='utf-8', level=log_level)

logging.info("Creating Inference Client")
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

def format_prompt(message, history):
    """Formats the prompt for the model using ChatML-style turn markers.

    Notes: the chat history is accepted for gr.ChatInterface compatibility but
    is not currently included in the prompt, and Mixtral-8x7B-Instruct's native
    template is [INST] ... [/INST], so the ChatML markers are sent as plain text.
    """
    logging.info("Formatting Prompt")
    logging.debug("Input Message: %s", message)
    logging.debug("Input History: %s", history)

    prompt = "<|im_start|>system\n" +\
        "You are Dolphin, a helpful AI assistant.<|im_end|>\n"
    prompt += "<|im_start|>user\n" + f"{message}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"

    return prompt

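# Illustrative output of format_prompt("Hello", []):
#   <|im_start|>system
#   You are Dolphin, a helpful AI assistant.<|im_end|>
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant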

def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    logging.info("Generating Response")
    logging.debug("Input Prompt: %s", prompt)
    logging.debug("Input History: %s", history)
    logging.debug("Input System Prompt: %s", system_prompt)
    logging.debug("Input Temperature: %s", temperature)
    logging.debug("Input Max New Tokens: %s", max_new_tokens)
    logging.debug("Input Top P: %s", top_p)
    logging.debug("Input Repetition Penalty: %s", repetition_penalty)

    logging.info("Converting Parameters to Correct Type")
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    logging.debug("Temperature: %s", temperature)
    logging.debug("Top P: %s", top_p)

    logging.info("Creating Generate kwargs")
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    logging.debug("Generate Args: %s", generate_kwargs)

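    # Note: the UI system prompt is prepended to the user message here rather
    # than replacing the hard-coded system turn inside format_prompt.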
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    logging.debug("Prompt: %s", formatted_prompt)

    logging.info("Generating Text")
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    logging.info("Creating Output")
    output = ""
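    # Yield the accumulated text after each streamed token so the Gradio UI
    # updates incrementally.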
    for response in stream:
        output += response.token.text
        yield output

    logging.debug("Output: %s", output)
    return output


additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

examples = []

logging.info("Creating Chat Interface")
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral Instruct",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)
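
# Running this module directly (e.g. `python app.py`, assuming this file is the
# Space's app.py) starts the Gradio server; set LOG_LEVEL=DEBUG for verbose logs.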