davidberenstein1957 HF staff commited on
Commit
e466982
Β·
1 Parent(s): a74ee73

Update chat interface with feedback Dani

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +42 -39
  3. chat_interface_preference.py +17 -21
  4. requirements.txt +7 -5
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: LLM Human Feedback Collector Chat Interface (DPO)
3
  emoji: 🦾πŸ’ͺ🏽
4
  colorFrom: pink
5
  colorTo: blue
@@ -12,4 +12,4 @@ suggested_hardware: t4-small
12
  short_description: LLM, chatbot
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: LLM Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO)
3
  emoji: 🦾πŸ’ͺ🏽
4
  colorFrom: pink
5
  colorTo: blue
 
12
  short_description: LLM, chatbot
13
  ---
14
 
15
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.py CHANGED
@@ -1,6 +1,5 @@
1
  #!/usr/bin/env python
2
  import os
3
- from threading import Thread
4
  from typing import Iterator
5
 
6
  import gradio as gr
@@ -15,10 +14,14 @@ DEFAULT_MAX_NEW_TOKENS = 1024
15
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
16
 
17
  if torch.cuda.is_available():
18
- model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
19
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
20
- tokenizer = AutoTokenizer.from_pretrained(model_id)
21
- style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
 
 
 
 
22
 
23
 
24
  @spaces.GPU
@@ -31,36 +34,39 @@ def generate(
31
  top_k: int = 40,
32
  repetition_penalty: float = 1.2,
33
  ) -> Iterator[str]:
34
- conversation = []
35
- for user, assistant in chat_history:
36
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
37
- conversation.append({"role": "user", "content": message})
38
-
39
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
40
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
41
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
42
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
43
- input_ids = input_ids.to(model.device)
44
-
45
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
46
- generate_kwargs = dict(
47
- {"input_ids": input_ids},
48
- streamer=streamer,
49
- max_new_tokens=max_new_tokens,
50
- do_sample=True,
51
- top_p=top_p,
52
- top_k=top_k,
53
- temperature=temperature,
54
- num_beams=1,
55
- repetition_penalty=repetition_penalty,
56
- )
57
- t = Thread(target=model.generate, kwargs=generate_kwargs)
58
- t.start()
59
-
60
- outputs = []
61
- for text in streamer:
62
- outputs.append(text)
63
- yield "".join(outputs)
 
 
 
64
 
65
 
66
  chat_interface = ChatInterface(
@@ -69,10 +75,7 @@ chat_interface = ChatInterface(
69
  min_turns=1,
70
  max_turns=10,
71
  repo_id="llm-human-feedback-collector-chat-interface-dpo",
72
- chatbot=gr.Chatbot(
73
- height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True
74
- ),
75
- css=style,
76
  cache_examples=False,
77
  additional_inputs=[
78
  gr.Slider(
 
1
  #!/usr/bin/env python
2
  import os
 
3
  from typing import Iterator
4
 
5
  import gradio as gr
 
14
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
15
 
16
  if torch.cuda.is_available():
17
+ # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
18
+ # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
19
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
20
+ pass
21
+
22
+ style = None
23
+
24
+ AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
25
 
26
 
27
  @spaces.GPU
 
34
  top_k: int = 40,
35
  repetition_penalty: float = 1.2,
36
  ) -> Iterator[str]:
37
+ # conversation = []
38
+ # for user, assistant in chat_history:
39
+ # conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
40
+ # conversation.append({"role": "user", "content": message})
41
+
42
+ # input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
43
+ # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
44
+ # input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
45
+ # gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
46
+ # input_ids = input_ids.to(model.device)
47
+
48
+ # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
49
+ # generate_kwargs = dict(
50
+ # {"input_ids": input_ids},
51
+ # streamer=streamer,
52
+ # max_new_tokens=max_new_tokens,
53
+ # do_sample=True,
54
+ # top_p=top_p,
55
+ # top_k=top_k,
56
+ # temperature=temperature,
57
+ # num_beams=1,
58
+ # repetition_penalty=repetition_penalty,
59
+ # )
60
+ # t = Thread(target=model.generate, kwargs=generate_kwargs)
61
+ # t.start()
62
+
63
+ # outputs = []
64
+ # for text in streamer:
65
+ # outputs.append(text)
66
+ # yield "".join(outputs)
67
+
68
+ for char in "help":
69
+ yield "help"
70
 
71
 
72
  chat_interface = ChatInterface(
 
75
  min_turns=1,
76
  max_turns=10,
77
  repo_id="llm-human-feedback-collector-chat-interface-dpo",
78
+ chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
 
 
 
79
  cache_examples=False,
80
  additional_inputs=[
81
  gr.Slider(
chat_interface_preference.py CHANGED
@@ -144,15 +144,15 @@ class ChatInterface(Blocks):
144
  submit_btn_bad = None
145
  stop_btn = "Stop"
146
  undo_btn = "↩️ Undo"
147
- clear_btn = "πŸ—‘οΈ Log and clear"
148
  if "kto" in prefence_techniques:
149
- submit_btn_good = "Log response πŸ‘"
150
- submit_btn_bad = "Log response πŸ‘Ž"
151
  if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
152
- submit_btn_two = "Generate 2"
153
- submit_btn_a = "Log preference πŸ…°οΈ"
154
- submit_btn_b = "Log preference πŸ…±οΈ"
155
- submit_btn_ab = "Continue random πŸ…°οΈ=πŸ…±οΈ"
156
  super().__init__(
157
  analytics_enabled=analytics_enabled,
158
  mode="chat_interface",
@@ -387,13 +387,13 @@ class ChatInterface(Blocks):
387
 
388
  def _setup_events(self) -> None:
389
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
390
- submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=1)
391
  submit_triggers_one = (
392
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
393
  )
394
  submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
395
  if self.submit_btn_two:
396
- submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
397
  submit_triggers_two = [self.submit_btn_two.click]
398
  submit_tuples.append((submit_fn_two, submit_triggers_two))
399
  for _fn, _triggers in submit_tuples:
@@ -608,7 +608,7 @@ class ChatInterface(Blocks):
608
  if turn[-1]:
609
  conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
610
 
611
- return self.css + "<body>" + conversation + "</body>"
612
 
613
  def _get_conversation_in_openai_format(self, history):
614
  conversation = []
@@ -622,26 +622,22 @@ class ChatInterface(Blocks):
622
 
623
  @staticmethod
624
  def _get_chat_message(message, role, turn):
625
- if role == "user":
626
- justify = "right"
627
- else:
628
- justify = "left"
629
  return (
630
- f'<div class="{role}-message" style="justify-content: {justify};">'
631
- + '<div class="message-content">'
632
- + f"<strong>Turn {turn} - {role.capitalize()}:</strong><br>"
633
  + f"<em>Length: {len(message)} characters</em><br><br>"
634
  + f'<div class="message-identifier">{message}</div>'
635
- + "</div></div>"
636
  )
637
 
638
  def _get_chat_message_comparison(self, content_a, content_b):
639
  return (
640
- '<div class="container">'
641
- + '<div class="column">'
642
  + self._get_chat_message(message=content_a, role="system", turn="A")
643
  + "</div>"
644
- + '<div class="column">'
645
  + self._get_chat_message(message=content_b, role="system", turn="B")
646
  + "</div>"
647
  + "</div>"
 
144
  submit_btn_bad = None
145
  stop_btn = "Stop"
146
  undo_btn = "↩️ Undo"
147
+ clear_btn = "πŸ—‘οΈ Clear"
148
  if "kto" in prefence_techniques:
149
+ submit_btn_good = "The response πŸ‘"
150
+ submit_btn_bad = "The response πŸ‘Ž"
151
  if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
152
+ submit_btn_two = None
153
+ submit_btn_a = "A is better than B"
154
+ submit_btn_b = "B is better than A"
155
+ submit_btn_ab = "A and B are similar"
156
  super().__init__(
157
  analytics_enabled=analytics_enabled,
158
  mode="chat_interface",
 
387
 
388
  def _setup_events(self) -> None:
389
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
390
+ submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=2)
391
  submit_triggers_one = (
392
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
393
  )
394
  submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
395
  if self.submit_btn_two:
396
+ submit_fn_two = functools.partial(submit_fn_one, n_generations=1)
397
  submit_triggers_two = [self.submit_btn_two.click]
398
  submit_tuples.append((submit_fn_two, submit_triggers_two))
399
  for _fn, _triggers in submit_tuples:
 
608
  if turn[-1]:
609
  conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
610
 
611
+ return "<body>" + self.css + conversation + "</body>"
612
 
613
  def _get_conversation_in_openai_format(self, history):
614
  conversation = []
 
622
 
623
  @staticmethod
624
  def _get_chat_message(message, role, turn):
625
+ # return f"<p><div class='message-identifier'>{message}</div></p>"
 
 
 
626
  return (
627
+ '<div class="message-content">'
628
+ + f"<strong>Option {turn} - </strong><br>"
 
629
  + f"<em>Length: {len(message)} characters</em><br><br>"
630
  + f'<div class="message-identifier">{message}</div>'
631
+ + "</div>"
632
  )
633
 
634
  def _get_chat_message_comparison(self, content_a, content_b):
635
  return (
636
+ '<div class="container" style="display: flex; width: 100%;">'
637
+ + '<div class="column" style="flex: 1; padding: 10px;">'
638
  + self._get_chat_message(message=content_a, role="system", turn="A")
639
  + "</div>"
640
+ + '<div class="column" style="flex: 1; padding: 10px;">'
641
  + self._get_chat_message(message=content_b, role="system", turn="B")
642
  + "</div>"
643
  + "</div>"
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
- accelerate==0.31.0
2
- bitsandbytes==0.42
3
- gradio==4.36.1
4
  scipy==1.13.0
5
- sentencepiece==0.2.0
6
  spaces==0.28.3
7
  torch==2.0.1
8
- transformers==4.41.2
 
 
 
 
 
 
1
+ gradio==4.39
 
 
2
  scipy==1.13.0
 
3
  spaces==0.28.3
4
  torch==2.0.1
5
+ accelerate
6
+ bitsandbytes
7
+ torch
8
+ transformers==4.43.1
9
+ einops
10
+ sentencepiece