File size: 4,281 Bytes
9276eae
0d322a6
9276eae
 
 
0fe7c54
0d322a6
 
9276eae
 
 
 
e5475e4
9276eae
 
 
 
 
 
 
e5475e4
9276eae
 
 
0fe7c54
9276eae
 
 
 
 
 
 
0d322a6
9276eae
0d322a6
 
 
 
9276eae
0d322a6
 
 
9276eae
0d322a6
 
 
 
 
0fe7c54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9276eae
 
 
0d322a6
 
0fe7c54
0d322a6
 
 
 
 
0fe7c54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d322a6
 
 
0fe7c54
 
 
 
 
 
 
 
 
 
 
 
 
0d322a6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers


adapters_name = "1littlecoder/mistral-7b-mj-finetuned"
model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"
device = "cuda"

# 8-bit bitsandbytes quantization. The bnb_4bit_* options below are 4-bit
# settings; per the BitsAndBytesConfig docs they have no effect while
# load_in_8bit=True (kept so switching to load_in_4bit=True is a one-line change).
bnb_config = transformers.BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Quantization is requested only via quantization_config: passing
# load_in_8bit as a separate keyword as well is rejected by recent
# transformers releases ("can't pass load_in_8bit ... and quantization_config").
# device_map='auto' places the weights, so the model must not be moved with .to().
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    device_map='auto'
)
# Wrap the quantized base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(model, adapters_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): explicitly pins BOS to id 1; presumably matches the Mistral
# vocabulary's "<s>" token — confirm if the base model changes.
tokenizer.bos_token_id = 1

# Not referenced anywhere else in this file; kept for external users.
stop_token_ids = [0]

print(f"Successfully loaded the model {model_name} into memory")

def remove_substring(original_string, substring_to_remove):
    """Return ``original_string`` with every occurrence of ``substring_to_remove`` deleted.

    :param original_string: The text to clean.
    :param substring_to_remove: The exact substring to strip out (all occurrences).
    :return: A new string; the input is not modified.
    """
    return original_string.replace(substring_to_remove, "")

def list_to_string(input_list, delimiter=" "):
    """
    Join the string form of each item in ``input_list``, separated by ``delimiter``.

    :param input_list: Items to join; each one is converted with ``str()``.
    :param delimiter: Separator placed between consecutive items (default: a single space).
    :return: The joined string (empty when ``input_list`` is empty).
    """
    pieces = (str(item) for item in input_list)
    return delimiter.join(pieces)

def format_prompt(message, history):
    """Build a Mistral-Instruct style prompt from the chat history and the new message.

    Each past exchange becomes ``[INST] user [/INST] bot</s> ``; the whole prompt
    starts with a literal ``<s>`` and ends with the open ``[INST] message [/INST]``
    turn the model should answer.

    :param message: The latest user message.
    :param history: Iterable of ``(user_prompt, bot_response)`` pairs.
    :return: The formatted prompt string.
    """
    past_turns = [
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Generate the model's reply to ``prompt`` given the prior chat ``history``.

    Used as the ``gr.ChatInterface`` callback; the four keyword parameters
    receive the values of the ``additional_inputs`` sliders.

    :param prompt: The latest user message.
    :param history: Prior ``(user_message, bot_response)`` pairs.
    :param temperature: Sampling temperature; clamped to at least 0.01.
    :param max_new_tokens: Upper bound on generated tokens; clamped to at least 1.
    :param top_p: Nucleus-sampling probability mass.
    :param repetition_penalty: Penalty applied to already-seen tokens.
    :return: The decoded reply only (prompt and special tokens removed).
    """
    # A temperature of 0 is invalid when sampling; clamp to a small positive value.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    # These slider values are now actually forwarded to model.generate().
    # (generate() has no `seed` argument; use torch.manual_seed for reproducibility.)
    generate_kwargs = dict(
        temperature=temperature,
        # Sliders may deliver floats, and a budget of 0 tokens is not a valid request.
        max_new_tokens=max(int(max_new_tokens), 1),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history)

    # format_prompt already emits "<s>", so the tokenizer must not add another BOS.
    encoded = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False)
    # The quantized model was placed by device_map="auto" and cannot be moved
    # with .to(); move the input tensors to the model's device instead.
    encoded = encoded.to(model.device)
    prompt_length = encoded["input_ids"].shape[-1]

    generated_ids = model.generate(**encoded, **generate_kwargs)

    # Decode only the newly generated tokens: this drops the echoed prompt and
    # special tokens like "</s>" without relying on exact string round-tripping.
    reply_ids = generated_ids[0, prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)


# Extra controls shown with the chat box. gr.ChatInterface passes their values
# to generate() after (message, history), in this order: temperature,
# max_new_tokens, top_p, repetition_penalty.
additional_inputs=[
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # 0 new tokens is not a valid generation budget, and the previous
        # maximum (1048) was not reachable on the 64-token step grid.
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

# Custom CSS for an element with id "mkd".
# NOTE(review): no component created below sets elem_id="mkd", so this rule
# currently has no visible target — confirm whether it is still needed.
css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    # Header HTML: closing tags were previously written as opening tags
    # ("<h1><center>...<h1><center>"), leaving the elements unclosed/nested.
    gr.HTML("<h1><center>Mistral 7B Instruct</center></h1>")
    gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. 💬</center></h3>")
    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
    )

# Every request shares the one GPU-resident model, so requests are queued and
# served with the queue's default concurrency rather than 75 parallel
# generations. `concurrency_count` is also rejected by queue() in Gradio 4.
demo.queue(max_size=100).launch(debug=True)