davidberenstein1957 HF staff commited on
Commit
a38b23a
·
1 Parent(s): 7c45d36

Update error message duplicate mistakes

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +8 -4
  3. chat_interface_preference.py +87 -77
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦾💪🏽
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: true
10
  license: mit
 
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.39
8
  app_file: app.py
9
  pinned: true
10
  license: mit
app.py CHANGED
@@ -1,13 +1,17 @@
1
  #!/usr/bin/env python
2
  import os
3
  import random
4
- from threading import Thread
5
  from typing import Iterator
6
 
7
  import gradio as gr
8
  import spaces
9
- import torch
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
 
 
11
 
12
  from chat_interface_preference import ChatInterface
13
 
@@ -118,7 +122,7 @@ chat_interface = ChatInterface(
118
  title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
119
  description="".join(
120
  [
121
- "This is an adaptation of the [`gr.ChatInferface`](https://www.gradio.app/docs/gradio/chatinterface) and [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) which allows for human feedback collection. ",
122
  "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
123
  "This demo shows how you might capture human feedback directly from applications within Gradio. ",
124
  "The captured feedback can directly be used for fine-tuning LLMs within framework like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
 
1
  #!/usr/bin/env python
2
  import os
3
  import random
4
+ from threading import Thread # noqa
5
  from typing import Iterator
6
 
7
  import gradio as gr
8
  import spaces
9
+ import torch # noqa
10
+ from transformers import (
11
+ AutoModelForCausalLM, # noqa
12
+ AutoTokenizer, # noqa
13
+ TextIteratorStreamer, # noqa
14
+ )
15
 
16
  from chat_interface_preference import ChatInterface
17
 
 
122
  title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
123
  description="".join(
124
  [
125
+ "This is an adaptation of the [`gr.ChatInferface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
126
  "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
127
  "This demo shows how you might capture human feedback directly from applications within Gradio. ",
128
  "The captured feedback can directly be used for fine-tuning LLMs within framework like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
chat_interface_preference.py CHANGED
@@ -607,7 +607,7 @@ class ChatInterface(Blocks):
607
  if turn[-1]:
608
  conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
609
 
610
- return "<body>" + self.css + conversation + "</body>"
611
 
612
  def _get_conversation_in_openai_format(self, history):
613
  conversation = []
@@ -644,6 +644,7 @@ class ChatInterface(Blocks):
644
 
645
  @staticmethod
646
  def _check_if_two_responses(response):
 
647
  if response:
648
  matches = pattern.findall(response)
649
  return matches
@@ -683,30 +684,34 @@ class ChatInterface(Blocks):
683
 
684
  self._check_message(message)
685
  self._check_num_turns(history)
686
- _, response = history_with_input[-1]
 
 
 
687
  if self._check_if_two_responses(response):
688
- raise Error("Two options detected: undo, log or random pick continuation.")
 
 
 
689
 
690
- inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
 
 
 
 
 
691
 
692
- async def _get_response():
693
- if self.is_async:
694
- response = await self.fn(*inputs)
695
  else:
696
- response = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
697
- return response
698
 
699
- if n_generations == 1:
700
- response = await _get_response()
701
- else:
702
- response_one, response_two = await _get_response(), await _get_response()
703
- response = self._get_chat_message_comparison(response_one, response_two)
704
-
705
- if self.multimodal and isinstance(message, dict):
706
- self._append_multimodal_history(message, response, history)
707
- elif isinstance(message, str):
708
- history.append([message, response])
709
- return history, history
710
 
711
  async def _stream_fn(
712
  self,
@@ -723,67 +728,35 @@ class ChatInterface(Blocks):
723
  history = history_with_input[:-1]
724
  self._check_message(message)
725
  self._check_num_turns(history)
726
- _, response = history_with_input[-1]
727
- if self._check_if_two_responses(response):
728
- raise Error("Two options detected: undo, log or random pick continuation.")
729
-
730
- inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
731
 
732
- try:
733
- if self.is_async:
734
- generator = self.fn(*inputs)
735
- else:
736
- generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
737
- generator = SyncToAsyncIterator(generator, self.limiter)
738
- first_response = await async_iteration(generator)
739
- if n_generations == 2:
740
- first_response_formatted = self._get_chat_message_comparison(first_response, "")
741
- else:
742
- first_response_formatted = first_response
743
- if self.multimodal and isinstance(message, dict):
744
- for x in message["files"]:
745
- history.append([(x,), None])
746
- update = history + [[message["text"], first_response_formatted]]
747
- yield update, update
748
- else:
749
- update = history + [[message, first_response_formatted]]
750
- yield update, update
751
- except StopIteration:
752
- if self.multimodal and isinstance(message, dict):
753
- self._append_multimodal_history(message, None, history)
754
- yield history, history
755
- else:
756
- update = history + [[message, None]]
757
- yield update, update
758
- async for response in generator:
759
- if n_generations == 2:
760
- response_formatted = self._get_chat_message_comparison(response, "")
761
- else:
762
- response_formatted = response
763
- if self.multimodal and isinstance(message, dict):
764
- update = history + [[message["text"], response_formatted]]
765
- yield update, update
766
- else:
767
- update = history + [[message, response_formatted]]
768
- yield update, update
769
 
770
- if n_generations == 2:
771
- if self.is_async:
772
- generator_two = self.fn(*inputs)
773
- else:
774
- generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
775
- generator_two = SyncToAsyncIterator(generator_two, self.limiter)
776
  try:
777
- first_response_two = await async_iteration(generator_two)
778
- first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
 
 
 
 
 
 
 
 
779
  if self.multimodal and isinstance(message, dict):
780
  for x in message["files"]:
781
  history.append([(x,), None])
782
-
783
- update = history + [[message["text"], first_response_two_formatted]]
784
  yield update, update
785
  else:
786
- update = history + [[message, first_response_two_formatted]]
787
  yield update, update
788
  except StopIteration:
789
  if self.multimodal and isinstance(message, dict):
@@ -792,15 +765,52 @@ class ChatInterface(Blocks):
792
  else:
793
  update = history + [[message, None]]
794
  yield update, update
795
- async for response_two in generator_two:
796
- response_two = self._get_chat_message_comparison(response, response_two)
 
 
 
797
  if self.multimodal and isinstance(message, dict):
798
- update = history + [[message["text"], response_two]]
799
  yield update, update
800
  else:
801
- update = history + [[message, response_two]]
802
  yield update, update
803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
  async def _log_fn(
805
  self, message: str | dict[str, list], history: list[list[str | tuple | None]], log: str
806
  ) -> tuple[
 
607
  if turn[-1]:
608
  conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
609
 
610
+ return "<body>" + conversation + "</body>"
611
 
612
  def _get_conversation_in_openai_format(self, history):
613
  conversation = []
 
644
 
645
  @staticmethod
646
  def _check_if_two_responses(response):
647
+ print(response)
648
  if response:
649
  matches = pattern.findall(response)
650
  return matches
 
684
 
685
  self._check_message(message)
686
  self._check_num_turns(history)
687
+ if history:
688
+ _, response = history[-1]
689
+ else:
690
+ response = None
691
  if self._check_if_two_responses(response):
692
+ Info("Two options detected: provide preference, undo or clear to continue conversation.")
693
+ return history, history
694
+ else:
695
+ inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
696
 
697
+ async def _get_response():
698
+ if self.is_async:
699
+ response = await self.fn(*inputs)
700
+ else:
701
+ response = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
702
+ return response
703
 
704
+ if n_generations == 1:
705
+ response = await _get_response()
 
706
  else:
707
+ response_one, response_two = await _get_response(), await _get_response()
708
+ response = self._get_chat_message_comparison(response_one, response_two)
709
 
710
+ if self.multimodal and isinstance(message, dict):
711
+ self._append_multimodal_history(message, response, history)
712
+ elif isinstance(message, str):
713
+ history.append([message, response])
714
+ return history, history
 
 
 
 
 
 
715
 
716
  async def _stream_fn(
717
  self,
 
728
  history = history_with_input[:-1]
729
  self._check_message(message)
730
  self._check_num_turns(history)
 
 
 
 
 
731
 
732
+ if history:
733
+ _, response = history[-1]
734
+ else:
735
+ response = None
736
+ if self._check_if_two_responses(response):
737
+ Info("Two options detected: provide preference, undo or clear to continue conversation.")
738
+ yield history, history
739
+ else:
740
+ inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
 
 
 
 
 
 
 
742
  try:
743
+ if self.is_async:
744
+ generator = self.fn(*inputs)
745
+ else:
746
+ generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
747
+ generator = SyncToAsyncIterator(generator, self.limiter)
748
+ first_response = await async_iteration(generator)
749
+ if n_generations == 2:
750
+ first_response_formatted = self._get_chat_message_comparison(first_response, "")
751
+ else:
752
+ first_response_formatted = first_response
753
  if self.multimodal and isinstance(message, dict):
754
  for x in message["files"]:
755
  history.append([(x,), None])
756
+ update = history + [[message["text"], first_response_formatted]]
 
757
  yield update, update
758
  else:
759
+ update = history + [[message, first_response_formatted]]
760
  yield update, update
761
  except StopIteration:
762
  if self.multimodal and isinstance(message, dict):
 
765
  else:
766
  update = history + [[message, None]]
767
  yield update, update
768
+ async for response in generator:
769
+ if n_generations == 2:
770
+ response_formatted = self._get_chat_message_comparison(response, "")
771
+ else:
772
+ response_formatted = response
773
  if self.multimodal and isinstance(message, dict):
774
+ update = history + [[message["text"], response_formatted]]
775
  yield update, update
776
  else:
777
+ update = history + [[message, response_formatted]]
778
  yield update, update
779
 
780
+ if n_generations == 2:
781
+ if self.is_async:
782
+ generator_two = self.fn(*inputs)
783
+ else:
784
+ generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
785
+ generator_two = SyncToAsyncIterator(generator_two, self.limiter)
786
+ try:
787
+ first_response_two = await async_iteration(generator_two)
788
+ first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
789
+ if self.multimodal and isinstance(message, dict):
790
+ for x in message["files"]:
791
+ history.append([(x,), None])
792
+
793
+ update = history + [[message["text"], first_response_two_formatted]]
794
+ yield update, update
795
+ else:
796
+ update = history + [[message, first_response_two_formatted]]
797
+ yield update, update
798
+ except StopIteration:
799
+ if self.multimodal and isinstance(message, dict):
800
+ self._append_multimodal_history(message, None, history)
801
+ yield history, history
802
+ else:
803
+ update = history + [[message, None]]
804
+ yield update, update
805
+ async for response_two in generator_two:
806
+ response_two = self._get_chat_message_comparison(response, response_two)
807
+ if self.multimodal and isinstance(message, dict):
808
+ update = history + [[message["text"], response_two]]
809
+ yield update, update
810
+ else:
811
+ update = history + [[message, response_two]]
812
+ yield update, update
813
+
814
  async def _log_fn(
815
  self, message: str | dict[str, list], history: list[list[str | tuple | None]], log: str
816
  ) -> tuple[