winglian committed on
Commit
cb78a36
·
unverified ·
1 Parent(s): 8b9c15b

improve tool handling roles (#1587)

Browse files
src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -1,7 +1,7 @@
1
  """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
2
 
3
  import logging
4
- from typing import Any, Dict, Optional
5
 
6
  from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
7
 
@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
39
  )
40
 
41
 
42
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the default ShareGPT tokenizing strategy.

    Honors the optional dataset-config keys ``conversation``, ``field_human``,
    ``field_model``, ``roles`` and ``strict``; any key that is absent falls
    back to the prompter's defaults.
    """
    opts = ds_cfg or {}
    prompter = ShareGPTPrompterV2(
        conversation=opts.get("conversation"),
        role_key_model=opts.get("field_model"),
        role_key_human=opts.get("field_human"),
        roles=opts["roles"].to_dict() if "roles" in opts else None,
    )
    strategy = SimpleShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "strict" in opts:
        strategy.strict = opts["strict"]
    return strategy
63
-
64
-
65
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the Ultrachat flavor of the ShareGPT tokenizing strategy.

    Only the ``conversation`` and ``strict`` dataset-config keys are consulted.
    """
    opts = ds_cfg or {}
    strategy = UltrachatShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(conversation=opts.get("conversation")),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "strict" in opts:
        strategy.strict = opts["strict"]
    return strategy
80
-
81
-
82
def load_role(tokenizer, cfg):
    """Build a strategy that reads the speaker from a ``role`` field instead of ``from``."""
    prompter = ShareGPTPrompterV2()
    return SimpleRoleShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
89
-
90
-
91
def load_guanaco(tokenizer, cfg):
    """Build the Guanaco flavor of the ShareGPT tokenizing strategy."""
    prompter = ShareGPTPrompterV2()
    return GuanacoShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
98
 
99
-
100
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the Glaive function-calling strategy.

    Falls back to the ``chatml_glaive`` conversation template when the dataset
    config does not name one.
    """
    opts = ds_cfg or {}
    conversation = opts.get("conversation", "chatml_glaive")
    return GlaiveShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(conversation=conversation),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
112
 
113
 
114
  class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
158
  return turns
159
 
160
 
161
- class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
 
 
162
  """
163
  basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
164
  """
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
209
  conversation = merge_consecutive_messages(conversation)
210
 
211
  return conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
2
 
3
  import logging
4
+ from typing import Any, Dict, Optional, Type
5
 
6
  from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
7
 
 
39
  )
40
 
41
 
42
def build_loader(
    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
    prompter_cls: Type["ShareGPTPrompterV2"],
    default_conversation: Optional[str] = None,
):
    """Return a dataset loader bound to the given strategy and prompter classes.

    The returned callable has the conventional ``load(tokenizer, cfg, ds_cfg)``
    signature used by prompt-strategy modules. ``default_conversation`` is the
    conversation template used when the dataset config does not name one.
    """

    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
        opts = ds_cfg or {}
        conversation = opts.get("conversation", default_conversation)
        roles = opts["roles"].to_dict() if "roles" in opts else None
        prompter = prompter_cls(
            conversation=conversation,
            role_key_model=opts.get("field_model"),
            role_key_human=opts.get("field_human"),
            roles=roles,
        )
        strategy = tokenization_strategy_cls(
            prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        # Not every strategy class exposes ``strict``; only set it where supported.
        if "strict" in opts and hasattr(strategy, "strict"):
            strategy.strict = opts["strict"]
        return strategy

    return _load
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
 
122
  return turns
123
 
124
 
125
+ class SimpleRoleShareGPTPromptTokenizingStrategy(
126
+ SimpleShareGPTPromptTokenizingStrategy
127
+ ):
128
  """
129
  basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
130
  """
 
175
  conversation = merge_consecutive_messages(conversation)
176
 
177
  return conversation
178
+
179
+
180
# Public loader entry points. Each is a closure produced by ``build_loader``,
# pairing a tokenizing-strategy class with the ShareGPT v2 prompter so all
# loaders share the same ds_cfg handling (conversation/field_*/roles/strict).
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
    UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
    GlaiveShareGPTPromptTokenizingStrategy,
    ShareGPTPrompterV2,
    # Glaive tool-calling data ships its own chatml variant as the default.
    default_conversation="chatml_glaive",
)
src/axolotl/prompters.py CHANGED
@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
348
  )
349
 
350
  if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
351
- LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
 
 
 
352
 
353
  conv.append_message(role, sentence["value"])
354
 
 
348
  )
349
 
350
  if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
351
+ if (
352
+ role != "assistant"
353
+ ): # back to back assistant calls may be okay for tool calls
354
+ LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
355
 
356
  conv.append_message(role, sentence["value"])
357