MohamedRashad committed
Commit 74d9892 · Parent(s): 890774d

feat: Add Arabic-ORPO-Llama3 chatbot comparison functionality

Files changed (1)
  1. app.py +137 -75
app.py CHANGED
@@ -4,106 +4,162 @@ import torch
 import gradio as gr
 from threading import Thread
 
-base_model_id = "NousResearch/Meta-Llama-3-8B-Instruct"
-new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"
-
-# Reload tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained(base_model_id)
-base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-).eval()
-new_model = AutoModelForCausalLM.from_pretrained(
-    new_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-).eval()
-terminators = [
-    tokenizer.eos_token_id,
-    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+models_available = [
+    "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct",
+    "silma-ai/SILMA-9B-Instruct-v0.1.1",
+    "inceptionai/jais-adapted-7b-chat",
+    "inceptionai/jais-adapted-13b-chat",
+    "inceptionai/jais-family-6p7b-chat",
+    "inceptionai/jais-family-13b-chat",
 ]
 
+# Download every model up front; drop any that fail to load (iterate over a copy: removing from the list being iterated would skip entries)
+for model_id in list(models_available):
+    try:
+        AutoTokenizer.from_pretrained(model_id)
+        AutoModelForCausalLM.from_pretrained(model_id)
+    except Exception:
+        models_available.remove(model_id)
+
+tokenizer_a, model_a = None, None
+tokenizer_b, model_b = None, None
+
+def load_model_a(model_id):
+    global tokenizer_a, model_a
+    tokenizer_a = AutoTokenizer.from_pretrained(model_id)
+    print(f"model A: {tokenizer_a.eos_token}")
+    try:
+        model_a = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True,
+        ).eval()
+    except Exception:
+        print(f"Using default attention implementation in {model_id}")
+        model_a = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        ).eval()
+    return gr.update(label=model_id)
+
+def load_model_b(model_id):
+    global tokenizer_b, model_b
+    tokenizer_b = AutoTokenizer.from_pretrained(model_id)
+    print(f"model B: {tokenizer_b.eos_token}")
+    try:
+        model_b = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True,
+        ).eval()
+    except Exception:
+        print(f"Using default attention implementation in {model_id}")
+        model_b = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        ).eval()
+    return gr.update(label=model_id)
+
 @spaces.GPU(duration=120)
-def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
-    base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-    new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
+    text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True)
+    text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True)
 
     system_prompt_list = [{"role": "system", "content": system_prompt}]
     input_text_list = [{"role": "user", "content": input_text}]
 
-    base_chat_history = []
-    for user, assistant in base_chatbot:
-        base_chat_history.append({"role": "user", "content": user})
-        base_chat_history.append({"role": "assistant", "content": assistant})
+    chat_history_a = []
+    for user, assistant in chatbot_a:
+        chat_history_a.append({"role": "user", "content": user})
+        chat_history_a.append({"role": "assistant", "content": assistant})
 
-    new_chat_history = []
-    for user, assistant in new_chatbot:
-        new_chat_history.append({"role": "user", "content": user})
-        new_chat_history.append({"role": "assistant", "content": assistant})
+    chat_history_b = []
+    for user, assistant in chatbot_b:
+        chat_history_b.append({"role": "user", "content": user})
+        chat_history_b.append({"role": "assistant", "content": assistant})
 
-    base_messages = system_prompt_list + base_chat_history + input_text_list
-    new_messages = system_prompt_list + new_chat_history + input_text_list
+    base_messages = system_prompt_list + chat_history_a + input_text_list
+    new_messages = system_prompt_list + chat_history_b + input_text_list
 
-    base_input_ids = tokenizer.apply_chat_template(
+    input_ids_a = tokenizer_a.apply_chat_template(
         base_messages,
         add_generation_prompt=True,
         return_tensors="pt"
-    ).to(base_model.device).long()
+    ).to(model_a.device)
 
-    new_input_ids = tokenizer.apply_chat_template(
+    input_ids_b = tokenizer_b.apply_chat_template(
         new_messages,
         add_generation_prompt=True,
         return_tensors="pt"
-    ).to(new_model.device).long()
+    ).to(model_b.device)
 
-    base_generation_kwargs = dict(
-        input_ids=base_input_ids,
-        streamer=base_text_streamer,
+    generation_kwargs_a = dict(
+        input_ids=input_ids_a,
+        streamer=text_streamer_a,
         max_new_tokens=max_new_tokens,
-        eos_token_id=terminators,
-        pad_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer_a.eos_token_id,
         do_sample=True if temperature > 0 else False,
         temperature=temperature,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
     )
-    new_generation_kwargs = dict(
-        input_ids=new_input_ids,
-        streamer=new_text_streamer,
+    generation_kwargs_b = dict(
+        input_ids=input_ids_b,
+        streamer=text_streamer_b,
         max_new_tokens=max_new_tokens,
-        eos_token_id=terminators,
-        pad_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer_b.eos_token_id,
        do_sample=True if temperature > 0 else False,
         temperature=temperature,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
     )
 
-    base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
-    base_thread.start()
-
-    base_chatbot.append([input_text, ""])
-    new_chatbot.append([input_text, ""])
-
-    for base_text in base_text_streamer:
-        if "<|eot_id|>" in base_text:
-            eot_location = base_text.find("<|eot_id|>")
-            base_text = base_text[:eot_location]
-        base_chatbot[-1][-1] += base_text
-        yield base_chatbot, new_chatbot
-
-    new_thread = Thread(target=new_model.generate, kwargs=new_generation_kwargs)
-    new_thread.start()
-
-    for new_text in new_text_streamer:
-        if "<|eot_id|>" in new_text:
-            eot_location = new_text.find("<|eot_id|>")
-            new_text = new_text[:eot_location]
-        new_chatbot[-1][-1] += new_text
-        yield base_chatbot, new_chatbot
-
-    return base_chatbot, new_chatbot
+    thread_a = Thread(target=model_a.generate, kwargs=generation_kwargs_a)
+    thread_b = Thread(target=model_b.generate, kwargs=generation_kwargs_b)
+
+    thread_a.start()
+    thread_b.start()
+
+    chatbot_a.append([input_text, ""])
+    chatbot_b.append([input_text, ""])
+
+    finished_a = False
+    finished_b = False
+
+    while not (finished_a and finished_b):
+        if not finished_a:
+            try:
+                text_a = next(text_streamer_a)
+                if tokenizer_a.eos_token in text_a:
+                    eot_location = text_a.find(tokenizer_a.eos_token)
+                    text_a = text_a[:eot_location]
+                    finished_a = True
+                chatbot_a[-1][-1] += text_a
+                yield chatbot_a, chatbot_b
+            except StopIteration:
+                finished_a = True
+
+        if not finished_b:
+            try:
+                text_b = next(text_streamer_b)
+                if tokenizer_b.eos_token in text_b:
+                    eot_location = text_b.find(tokenizer_b.eos_token)
+                    text_b = text_b[:eot_location]
+                    finished_b = True
+                chatbot_b[-1][-1] += text_b
+                yield chatbot_a, chatbot_b
+            except StopIteration:
+                finished_b = True
+
+    return chatbot_a, chatbot_b
 
 def clear():
     return [], []
@@ -113,8 +169,11 @@ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
     gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
     system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
     with gr.Row(variant="panel"):
-        base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
-        new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
+        model_dropdown_a = gr.Dropdown(label="Model A", choices=models_available, value=None)
+        model_dropdown_b = gr.Dropdown(label="Model B", choices=models_available, value=None)
+    with gr.Row(variant="panel"):
+        chatbot_a = gr.Chatbot(label="Model A", rtl=True, likeable=True, show_copy_button=True, height=500)
+        chatbot_b = gr.Chatbot(label="Model B", rtl=True, likeable=True, show_copy_button=True, height=500)
     with gr.Row(variant="panel"):
         with gr.Column(scale=1):
             submit_btn = gr.Button(value="Generate", variant="primary")
@@ -126,8 +185,11 @@ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
         top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
         repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
 
-    input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
-    submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
-    clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
+    model_dropdown_a.change(load_model_a, inputs=[model_dropdown_a], outputs=[chatbot_a])
+    model_dropdown_b.change(load_model_b, inputs=[model_dropdown_b], outputs=[chatbot_b])
+
+    input_text.submit(generate_both, inputs=[system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[chatbot_a, chatbot_b])
+    submit_btn.click(generate_both, inputs=[system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[chatbot_a, chatbot_b])
+    clear_btn.click(clear, outputs=[chatbot_a, chatbot_b])
 
-demo.launch()
+demo.launch()
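
The interesting part of the new generate_both is the drain loop: both model.generate calls are launched in background threads up front, and the two TextIteratorStreamer iterators are then consumed one chunk at a time with next(), so both chat panels fill in live. The previous version exhausted the base model's streamer completely before even starting the second thread, so the second panel stayed empty until the first finished. Below is a minimal, model-free sketch of that pattern; fake_stream, the "<eos>" marker, and the delays are illustrative stand-ins for the streamers fed by model.generate, not part of the app:

import time

def fake_stream(chunks, delay):
    # Stand-in for a TextIteratorStreamer: yields text chunks with latency.
    for chunk in chunks:
        time.sleep(delay)
        yield chunk

stream_a = fake_stream(["مرحبا", " بك", "<eos>"], delay=0.05)
stream_b = fake_stream(["أهلا", " وسهلا", " بك", "<eos>"], delay=0.08)

reply_a, reply_b = "", ""
finished_a = finished_b = False
while not (finished_a and finished_b):
    if not finished_a:
        try:
            text = next(stream_a)
            if "<eos>" in text:                 # trim the end-of-sequence marker
                text = text[:text.find("<eos>")]
                finished_a = True
            reply_a += text
        except StopIteration:
            finished_a = True
    if not finished_b:
        try:
            text = next(stream_b)
            if "<eos>" in text:
                text = text[:text.find("<eos>")]
                finished_b = True
            reply_b += text
        except StopIteration:
            finished_b = True
    print(f"A: {reply_a!r} | B: {reply_b!r}")   # the app yields both chatbots here

assert reply_a == "مرحبا بك" and reply_b == "أهلا وسهلا بك"

Each print above corresponds to a "yield chatbot_a, chatbot_b" in the app, which is what lets Gradio repaint both panels on every chunk.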
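
The preflight loop at the top of the new file doubles as a filter: any repo that fails to download is dropped, so it never appears in the dropdowns. A self-contained sketch of that filtering pattern, with a stub loader standing in for AutoModelForCausalLM.from_pretrained (filter_loadable and fake_load are hypothetical names, not part of the app):

def filter_loadable(candidates, try_load):
    """Return the candidates that try_load accepts, preserving order."""
    usable = []
    for model_id in candidates:
        try:
            try_load(model_id)
        except Exception:
            continue  # skip anything that fails to download or load
        usable.append(model_id)
    return usable

def fake_load(model_id):
    # Stub loader: pretend one repo id is broken.
    if "broken" in model_id:
        raise OSError(f"{model_id} not found")

models = ["good/model-a", "broken/model-b", "good/model-c"]
assert filter_loadable(models, fake_load) == ["good/model-a", "good/model-c"]

Building a fresh list sidesteps the classic pitfall of calling remove() on the list being iterated, which silently skips the element that follows each removal.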