pyakhurel committed
Commit 3d31898 · Parent(s): 0e631ab

Create app.py

Files changed (1): app.py (+147 -0)
app.py ADDED
import torch
import gradio as gr
import transformers
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

adapters_name = "1littlecoder/mistral-7b-mj-finetuned"
model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"
device = "cuda"

# Quantize the base model to 4-bit NF4 (with double quantization) so it fits
# on a single GPU; compute runs in bfloat16.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    device_map="auto",
)
# Layer the fine-tuned LoRA adapter on top of the quantized base model.
model = PeftModel.from_pretrained(model, adapters_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.bos_token_id = 1

stop_token_ids = [0]  # declared for stop handling; currently unused

print(f"Successfully loaded the model {model_name} into memory")


def remove_substring(original_string, substring_to_remove):
    # Replace the substring with an empty string.
    return original_string.replace(substring_to_remove, "")


def list_to_string(input_list, delimiter=" "):
    """
    Convert a list to a string, joining elements with the specified delimiter.

    :param input_list: The list to convert to a string.
    :param delimiter: The separator to use between elements (default is a space).
    :return: A string composed of list elements separated by the delimiter.
    """
    return delimiter.join(map(str, input_list))


def format_prompt(message, history):
    # Build a Mistral-Instruct conversation: each user turn is wrapped in
    # [INST] ... [/INST], and each bot response is terminated with </s>.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt


def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history)

    # The prompt already carries <s>/</s>, so skip the tokenizer's own special
    # tokens; move the inputs to the GPU to match the model placement.
    # (A 4-bit model loaded with device_map="auto" must not be moved with .to().)
    model_input = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(device)
    # Pass the UI-controlled sampling settings through to generation.
    generated_ids = model.generate(**model_input, **generate_kwargs)

    # Decode, then strip the echoed prompt so only the new completion remains.
    list_output = tokenizer.batch_decode(generated_ids)
    string_output = list_to_string(list_output)
    possible_output = remove_substring(string_output, formatted_prompt)

    return possible_output


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Mistral 7B Instruct</center></h1>")
    gr.HTML("<h3><center>In this demo, you can chat with the <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. 💬</center></h3>")
    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]],
    )

demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
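
For reference, here is a minimal, self-contained sketch of the conversation template that format_prompt builds. The sample turns are illustrative only and are not part of the app:

# Standalone illustration of the app's format_prompt template.
# The history and message below are made up for demonstration.

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("What is Mistral 7B?", "A 7-billion-parameter language model.")]
print(format_prompt("Can I fine-tune it?", history))
# <s>[INST] What is Mistral 7B? [/INST] A 7-billion-parameter language model.</s> [INST] Can I fine-tune it? [/INST]

Because the app tokenizes with add_special_tokens=False, the literal <s> and </s> markers written into this template are the only BOS/EOS tokens the model sees.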