PLatonG commited on
Commit
fb5ecd8
1 Parent(s): 4745a53

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -0
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Prompter import Prompter
2
+ from Callback import Stream, Iteratorize
3
+ import os
4
+ import sys
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import transformers
9
+ from peft import PeftModel
10
+ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
11
+ import pandas as pd
12
+ import numpy as np
13
+
14
+ if torch.cuda.is_available():
15
+ device = "cuda"
16
+ else:
17
+ device = "cpu"
18
+
19
+ try:
20
+ if torch.backends.mps.is_available():
21
+ device = "mps"
22
+ except: # noqa: E722
23
+ pass
24
+
25
+ base_model = "openthaigpt/openthaigpt-1.0.0-beta-7b-chat-ckpt-hf"
26
+ load_8bit = True
27
+ # lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations"
28
+ lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations"
29
+ prompter = Prompter("alpaca")
30
+ tokenizer = LlamaTokenizer.from_pretrained(base_model)
31
+
32
+ model = LlamaForCausalLM.from_pretrained(
33
+ base_model,
34
+ load_in_8bit=load_8bit,
35
+ torch_dtype=torch.float16,
36
+ device_map="auto",
37
+ offload_folder = "./offload"
38
+ )
39
+ model = PeftModel.from_pretrained(
40
+ model,
41
+ lora_weights,
42
+ torch_dtype=torch.float16,
43
+ offload_folder = "./offload"
44
+ )
45
+
46
+ # unwind broken decapoda-research config
47
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
48
+ model.config.bos_token_id = 1
49
+ model.config.eos_token_id = 2
50
+
51
+ if not load_8bit:
52
+ model.half() # seems to fix bugs for some users.
53
+
54
+ model.eval()
55
+ if torch.__version__ >= "2" and sys.platform != "win32":
56
+ model = torch.compile(model)
57
+
58
+ def evaluate(
59
+ instruction,
60
+ input=None,
61
+ stream_output=False,
62
+ ):
63
+ temperature=0.1
64
+ top_p=0.25
65
+ top_k=30
66
+ num_beams=4
67
+ max_new_tokens=380
68
+
69
+ prompt = prompter.generate_prompt(instruction, input)
70
+ inputs = tokenizer(prompt, return_tensors="pt")
71
+ input_ids = inputs["input_ids"].to(device)
72
+
73
+ generation_config = GenerationConfig(
74
+ temperature=temperature,
75
+ top_p=top_p,
76
+ top_k=top_k,
77
+ num_beams=num_beams,
78
+
79
+ )
80
+ # generation_config = GenerationConfig(
81
+ # do_sample = True,
82
+ # num_beams = 4,
83
+ # )
84
+
85
+ generate_params = {
86
+ "input_ids": input_ids,
87
+ "generation_config": generation_config,
88
+ "return_dict_in_generate": True,
89
+ "output_scores": True,
90
+ "max_new_tokens": max_new_tokens,
91
+ }
92
+
93
+ if stream_output:
94
+ # Stream the reply 1 token at a time.
95
+ # This is based on the trick of using 'stopping_criteria' to create an iterator,
96
+ # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
97
+
98
+ def generate_with_callback(callback=None, **kwargs):
99
+ kwargs.setdefault(
100
+ "stopping_criteria", transformers.StoppingCriteriaList()
101
+ )
102
+ kwargs["stopping_criteria"].append(
103
+ Stream(callback_func=callback)
104
+ )
105
+ with torch.no_grad():
106
+ model.generate(**kwargs)
107
+
108
+ def generate_with_streaming(**kwargs):
109
+ return Iteratorize(
110
+ generate_with_callback, kwargs, callback=None
111
+ )
112
+
113
+ with generate_with_streaming(**generate_params) as generator:
114
+ for output in generator:
115
+ # new_tokens = len(output) - len(input_ids[0])
116
+ decoded_output = tokenizer.decode(output)
117
+
118
+ if output[-1] in [tokenizer.eos_token_id]:
119
+ break
120
+
121
+ yield prompter.get_response(decoded_output)
122
+ return # early return for stream_output
123
+
124
+ # Without streaming
125
+ with torch.no_grad():
126
+ generation_output = model.generate(
127
+ input_ids=input_ids,
128
+ generation_config=generation_config,
129
+ return_dict_in_generate=True,
130
+ output_scores=True,
131
+ max_new_tokens=max_new_tokens,
132
+ )
133
+ s = generation_output.sequences[0]
134
+ output = tokenizer.decode(s)
135
+ yield prompter.get_response(output)
136
+
137
+
138
+ # From SMOTE with 4 neightbor
139
+ fourNSMOTE = pd.read_csv("FILTER_GREATERTHANTHREE_FROM_SHEETS_SMOTE_train.csv")
140
+
141
+ with gr.Blocks(fill_height = True, title="Expert Recommendations") as demo:
142
+ with gr.Row():
143
+ birth_year = gr.components.Number(minimum = 2536, maximum = 2557, value= 2545,
144
+ label="ปีเกิด",
145
+ info="ต่ำสุด : 2536 สูงสุด : 2557")
146
+ nationality_name = gr.components.Dropdown(choices=fourNSMOTE.NATIONALITY_NAME.unique().tolist(),
147
+ label="สัญชาติ",
148
+ value = fourNSMOTE.NATIONALITY_NAME.unique().tolist()[0])
149
+ religion_name = gr.components.Dropdown(choices=fourNSMOTE.RELIGION_NAME.unique().tolist(),
150
+ label="ศาสนา",
151
+ value = fourNSMOTE.RELIGION_NAME.unique().tolist()[0])
152
+ with gr.Row():
153
+ sex = gr.components.Dropdown(choices=fourNSMOTE.JVN_SEX.unique().tolist(),
154
+ label="เพศ",
155
+ value = fourNSMOTE.JVN_SEX.unique().tolist()[0])
156
+ inform_status = gr.components.Dropdown(choices=fourNSMOTE.INFORM_STATUS_TXT.unique().tolist(),
157
+ label="เหตุที่นำมาสู่การดำเนินคดี",
158
+ value = fourNSMOTE.INFORM_STATUS_TXT.unique().tolist()[0])
159
+ age = gr.components.Number(minimum = 10, maximum = 19, value= 17,
160
+ label="อายุตอนกระทำผิด",
161
+ info="ต่ำสุด : 10 ปี สูงสุด : 19")
162
+ with gr.Row():
163
+
164
+ offense_name = gr.components.Dropdown(choices=fourNSMOTE.OFFENSE_NAME.unique().tolist(),
165
+ label="คดีที่กระทำผิด",
166
+ value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0])
167
+
168
+ ref_value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0]
169
+
170
+ allegation_name = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_NAME.unique().tolist(), label="ชื่อของข้อกล่าวหา",
171
+ value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_NAME"].unique().tolist()[0])
172
+
173
+ allegation_desc = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_DESC.unique().tolist(), label="รายละเอียดของข้อกล่าวหา",
174
+ value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_DESC"].unique().tolist()[0])
175
+
176
+ def update_dropDown_allegation(value):
177
+ allegation_query = fourNSMOTE.query("OFFENSE_NAME == @value")
178
+ data = allegation_query["ALLEGATION_NAME"].unique().tolist()
179
+ allegation_name = gr.components.Dropdown(choices=data, value=data[0])
180
+ # allegation_desc = gr.components.Dropdown(choices=query_state["ALLEGATION_DESC"].unique().tolist())
181
+ return allegation_name
182
+
183
+ def update_dropDown_allegation_desc(offense_name, allegation_name):
184
+ allegationDesc_query = fourNSMOTE.query("OFFENSE_NAME == @offense_name and ALLEGATION_NAME == @allegation_name")
185
+ data = allegationDesc_query["ALLEGATION_DESC"].unique().tolist()
186
+ allegation_desc = gr.components.Dropdown(choices=data, value=data[0])
187
+ # allegation_desc = gr.components.Dropdown(choices=query_state["ALLEGATION_DESC"].unique().tolist())
188
+ return allegation_desc
189
+
190
+ offense_name.change(fn=update_dropDown_allegation, inputs=offense_name, outputs=[allegation_name])
191
+ offense_name.change(fn=update_dropDown_allegation_desc, inputs=[offense_name, allegation_name], outputs=[allegation_desc])
192
+ allegation_name.change(fn=update_dropDown_allegation_desc, inputs=[offense_name, allegation_name], outputs=[allegation_desc])
193
+
194
+
195
+
196
+
197
+ with gr.Row():
198
+
199
+ rn1 = gr.components.Radio(choices=["ถูก", "ผิด"],
200
+ label="ปรากฎลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่",
201
+ value="ถูก")
202
+ rn2 = gr.components.Radio(choices=["ถูก", "ผิด"],
203
+ label="ปรากฎประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย",
204
+ value = "ถูก")
205
+ rn3 = gr.components.Radio(choices=["ถูก", "ผิด"],
206
+ label="ปรากฎประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว",
207
+ value = "ถูก")
208
+ with gr.Row():
209
+
210
+ education = gr.components.Dropdown(choices=fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist(),
211
+ label="สถาณะการศึกษา",
212
+ value = fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist()[0])
213
+ occupation = gr.components.Dropdown(choices=fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist(),
214
+ label="สถาณะการประกอบอาชีพ",
215
+ value = fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist()[0])
216
+ province = gr.components.Dropdown(choices=fourNSMOTE.PROVINCE_NAME.unique().tolist(),
217
+ label="จังหวัดที่กระทำผิด",
218
+ value = fourNSMOTE.PROVINCE_NAME.unique().tolist()[0])
219
+
220
+
221
+ def generate_input(birth_year, nationality_name, religion_name, sex,
222
+ inform_status, age, offense_name, allegation_name,
223
+ allegation_desc, rn1, rn2, rn3, education, occupation, province):
224
+
225
+ birth_year = f"เกิดเมื่อปี พ.ศ. {int(birth_year)}"
226
+
227
+ if int(age) >= 10 or int(age) <=15:
228
+ age = f"มีอายุอยู่ในช่วง 10 ถึง 15 ปี"
229
+ elif int(age) >=16 or int(age) <= 20:
230
+ age = f"มีอายุอยู่ในช่วง 16 ถึง 20 ปี"
231
+ elif int(age) >=21 or int(age) <= 25:
232
+ age = f"มีอายุอยู่ในช่วง 21 ถึง 25 ปี"
233
+ elif int(age) >=26:
234
+ age = f"มีอายุอยู่ในช่วง 26 ปีขึ้นไป"
235
+
236
+ if rn1 == "ถูก":
237
+ rn1 = "มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่"
238
+ else:
239
+ rn1 = "ไม่มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่"
240
+
241
+ if rn2 == "ถูก":
242
+ rn2 = "มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย"
243
+ else:
244
+ rn2 = "ไม่มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย"
245
+
246
+ if rn3 == "ถูก":
247
+ rn3 = "มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว"
248
+ else:
249
+ rn3 = "ไม่มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว"
250
+
251
+ instruciton = "จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้"
252
+ input = f"{birth_year} {nationality_name} {religion_name} {sex} {inform_status} {age} {offense_name} {allegation_name} {allegation_desc} {rn1} {rn2} {rn3} {education} {occupation} {province}"
253
+
254
+
255
+ return input
256
+
257
+ def generate_full_input(inst ,input):
258
+ # output = ["True", "false"]
259
+ # input = np.random.choice(output)
260
+ # input = instruction + " " + input
261
+ # first_element = check[0] # user text
262
+ # last_element = check[-1] # input
263
+ # instruction = check[-2] # instruction
264
+ # input = f"{instruction} {last_element}"
265
+ return f"{inst} {input}"
266
+
267
+ def test_fucn(inst, input, stream):
268
+ return str(inst)
269
+
270
+
271
+
272
+ instruction = gr.Textbox(label = "คำสั่ง", value="จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้", visible=False, interactive=False)
273
+ input_compo = gr.Textbox(label = "ข้อมูลเข้า (input)", show_copy_button = True, visible=False)
274
+
275
+ # stream_output = gr.components.Checkbox(label="Stream output")
276
+
277
+ full_input = gr.Textbox(label = "full prompt", visible=True, show_copy_button=True)
278
+ btn1 = gr.Button("GENERATE INPUT")
279
+ # show input text format for user
280
+ btn1.click(fn=generate_input, inputs=[birth_year, nationality_name, religion_name, sex,
281
+ inform_status, age, offense_name, allegation_name,
282
+ allegation_desc, rn1, rn2, rn3, education, occupation, province],
283
+ outputs=input_compo)
284
+
285
+ # btn1.click(fn=generate_simple_output, inputs = [instruction, input_compo], outputs = full_input)
286
+ input_compo.change(fn = generate_full_input, inputs=[instruction, input_compo], outputs=full_input)
287
+
288
+
289
+ outputModel = gr.Textbox(label= "ผลลัพธ์ (output)")
290
+ btn2 = gr.Button("GENERATE OUTPUT")
291
+ btn2.click(fn=evaluate, inputs=[instruction, input_compo], outputs=outputModel)
292
+
293
+
294
+
295
+ # input text format for model
296
+ # btn.click(fn=generate_text_test2, inputs = [birth_year, nationality_name, religion_name, sex,
297
+ # inform_status, age, offense_name, allegation_name,
298
+ # allegation_desc, rn1, rn2, rn3, education, occupation, province],
299
+ # outputs = input_compo)
300
+
301
+
302
+
303
+ demo.launch(debug=True, share=True)