Feat: Add sharegpt multirole (#1137)
* feat(prompt): support multiple roles for sharegpt
* fix: add handling of empty role back
* feat: rebased and allowed more dynamic roles via config
* fix: variable
* chore: update message
* feat: add vicuna format
* fix: JSON serializable error
* fix: typing
* fix: don't remap for unknown keys
* fix: add roles to pydantic
* feat: add test
* chore: remove leftover print
* chore: remove leftover comment
* chore: remove print
* fix: update test to use chatml
README.md CHANGED
@@ -651,9 +651,13 @@ datasets:
     train_on_split: train # Optional[str] name of dataset split to load from

     # Optional[str] fastchat conversation type, only used with type: sharegpt
-    conversation:
+    conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
     field_human: # Optional[str]. Human key to use for conversation.
     field_model: # Optional[str]. Assistant key to use for conversation.
+    # Add additional keys from your dataset as input or output roles
+    roles:
+      input: # Optional[List[str]]. These will be masked based on train_on_input
+      output: # Optional[List[str]].

   # Custom user instruction prompt
   - path: repo
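For reference, once this YAML is parsed the loader sees `roles` as a plain mapping on the dataset config. Below is a minimal sketch of the equivalent `ds_cfg` dict, assuming a chatml conversation; the dataset path and the `tool`/`narrator` role names are hypothetical examples, not defaults:

```python
# Hypothetical dict form of the dataset config above, roughly as load() receives it.
ds_cfg = {
    "path": "my-org/my-sharegpt-dataset",  # hypothetical dataset repo
    "type": "sharegpt",
    "conversation": "chatml",
    "roles": {
        "input": ["tool"],       # masked from the loss when train_on_inputs is False
        "output": ["narrator"],  # always contributes to the loss
    },
}
```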
src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -1,5 +1,6 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

+import logging
 from typing import Any, Dict, Optional

 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
     merge_consecutive_messages,
 )

+LOG = logging.getLogger("axolotl")
+

 def register_chatml_template(system_message=None):
     system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
     strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
             role_key_human=field_human,
+            roles=roles,
         ),
         tokenizer,
         cfg.train_on_inputs,
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
             "system": "system",
         }
         turns = [
-            {"from": role_map[t[role_key]], "value": t[value_key]}
+            {
+                "from": (
+                    role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
+                ),
+                "value": t[value_key],
+            }
             for t in conversations
         ]
         return turns
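The `turns` change above is what keeps custom role keys alive: known sharegpt keys are remapped through `role_map`, and anything else (e.g. `tool`) now passes through untouched instead of raising a `KeyError`. A minimal standalone sketch of that behavior, with a trimmed `role_map`:

```python
def remap_turns(conversations, role_map, role_key="from", value_key="value"):
    """Map known sharegpt keys to conversation roles; pass unknown keys through."""
    return [
        {
            # dict.get falls back to the raw key, so "tool" survives unmapped
            "from": role_map.get(t[role_key], t[role_key]),
            "value": t[value_key],
        }
        for t in conversations
    ]


role_map = {"human": "human", "gpt": "gpt", "system": "system"}
turns = remap_turns(
    [
        {"from": "human", "value": "hi"},
        {"from": "tool", "value": "get_weather(New York)"},
    ],
    role_map,
)
assert turns[1]["from"] == "tool"  # unknown role kept as-is
```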
src/axolotl/prompt_tokenizers.py CHANGED
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
 from axolotl.monkeypatch.fastchat_conversation_turns import (
     add_get_turns_to_conversation,
 )
-from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

 LOG = logging.getLogger("axolotl")

@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

     def __init__(
         self,
-        prompter,
+        prompter: Prompter,
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             self.prompter._conversation.copy()  # pylint: disable=protected-access
         )

+        input_roles = {conversation.roles[0]}
+        output_roles = {conversation.roles[1]}
+
+        if len(conversation.roles) == 3:
+            tool_role_label = conversation.roles[2]
+            input_roles.add(tool_role_label)
+
+        # Add roles from the config
+        if self.prompter.roles:
+            if "input" in self.prompter.roles and self.prompter.roles["input"]:
+                for role in self.prompter.roles["input"]:
+                    input_roles.add(role)
+
+            if "output" in self.prompter.roles and self.prompter.roles["output"]:
+                for role in self.prompter.roles["output"]:
+                    output_roles.add(role)
+
         # support for custom roles from the dataset, only useful for vicuna style prompts/roles
         role_remap = []
         if (
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 LOG.warning(f"expected tuple, got {part}")
                 continue

-            tool_role_label = None
-            if len(conversation.roles) == 3:
-                (
-                    user_role_label,
-                    assistant_role_label,
-                    tool_role_label,
-                ) = conversation.roles
-            else:
-                user_role_label, assistant_role_label = conversation.roles
             role, content = part

             # Uses "in" because role contains extra characters
-            if user_role_label in role:
+            input_turn = any(r.lower() in role.lower() for r in input_roles)
+            output_turn = any(r.lower() in role.lower() for r in output_roles)
+            empty_role = role.strip() == ""
+
+            if not any([input_turn, output_turn, empty_role]):
+                LOG.warning(f"unhandled role: {role}")
+                continue
+
+            if input_turn:
                 role = (
                     role.replace(role_remap[0]["from"], role_remap[0]["to"])
                     if role_remap
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            elif assistant_role_label in role:
+            elif output_turn:
                 role = (
                     role.replace(role_remap[1]["from"], role_remap[1]["to"])
                     if role_remap
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                         len_role, len(labels)
                     )
-            elif role == "":
+            elif empty_role:
                 turn = content
                 # this is only ever the first part, should include the bos token and the user query
                 res = self._tokenize(
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            elif tool_role_label and tool_role_label in role:
-                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            else:
-                LOG.warning(f"unhandled role: {role}")
-                continue

             # pylint: disable=duplicate-code
             result, current_len = parse_tokenized_to_result(
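Net effect of the tokenizer-side changes: every role is bucketed once into an input set or an output set (the conversation's own two or three roles, plus any extras from the config), and each turn is then classified with a case-insensitive substring match because the rendered role carries template characters like `<|im_start|>`. A condensed sketch of that classification, assuming the same semantics as the diff; the concrete role names here are illustrative:

```python
IGNORE_TOKEN_ID = -100  # same masking sentinel the strategy uses


def classify_turn(role, input_roles, output_roles):
    """Mirror the input/output/empty bucketing from the diff."""
    # "in" matching, because role contains extra template characters
    if any(r.lower() in role.lower() for r in input_roles):
        return "input"   # labels masked unless train_on_inputs is set
    if any(r.lower() in role.lower() for r in output_roles):
        return "output"  # labels kept (minus the role prefix)
    if role.strip() == "":
        return "empty"   # first part: preamble before any role header, masked
    return None          # unhandled: the strategy logs a warning and skips


input_roles = {"user", "tool"}   # conversation.roles[0] plus config "input" roles
output_roles = {"assistant"}     # conversation.roles[1] plus config "output" roles
assert classify_turn("<|im_start|>tool", input_roles, output_roles) == "input"
```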
src/axolotl/prompters.py CHANGED
@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
     "Role did not alternate between turns (gpt and human). Please check your data."
 )

+CONVERSATION_ROLE_FORMAT = {
+    "chatml": "<|im_start|>{ROLE}",
+    "zephyr": "<|{ROLE}|>",
+    "vicuna_v1.1": "{ROLE}",
+}
+

 class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
     """
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
     role_key_human = "human"
     role_key_model = "gpt"
     # Optional, only used for tool usage datasets.
-    role_key_tool = None
+    role_key_tool: Optional[str] = None
+    # Optional, role input/output mapping
+    roles: Optional[dict] = None

     def __init__(
         self,
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
         role_key_tool: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         if conversation:
             if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
         self.role_key_model = role_key_model
         if role_key_tool:
             self.role_key_tool = role_key_tool
+        if roles:
+            self.roles = roles

     def _build_result(self, source):
         if len(source) < 2:
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods

         conv.messages = []
         for _, sentence in enumerate(source):
-            role = roles[sentence["from"]]
-            if len(conv.messages) > 0 and (
-                (role == conv.messages[-1][0]) or (role not in conv.roles)
-            ):
+            from_role = sentence["from"]
+            if from_role in roles:
+                role = roles[from_role]
+            else:
+                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
+                    raise NotImplementedError(
+                        f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
+                        "Please help us by creating an Issue to add support for this conversation type."
+                    )
+
+                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+                    ROLE=from_role
+                )
+
+            if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+
             conv.append_message(role, sentence["value"])

         return conv.get_turns()
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         super().__init__(
             conversation=conversation,
             role_key_human=role_key_human,
             role_key_model=role_key_model,
+            roles=roles,
         )
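`CONVERSATION_ROLE_FORMAT` is what lets a role FastChat has never heard of still render a turn header for the three supported templates. A small sketch of that fallback path, with a hypothetical `roles` mapping:

```python
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
}


def resolve_role(from_role, roles, conversation_name):
    """Known dataset keys use the prompter's role map; unknown keys fall back."""
    if from_role in roles:
        return roles[from_role]
    if conversation_name not in CONVERSATION_ROLE_FORMAT:
        raise NotImplementedError(
            f"{conversation_name} does not support role remapping yet"
        )
    return CONVERSATION_ROLE_FORMAT[conversation_name].format(ROLE=from_role)


roles = {"human": "user", "gpt": "assistant"}  # hypothetical mapping
assert resolve_role("tool", roles, "chatml") == "<|im_start|>tool"
```

One caveat worth noting: the diff's `NotImplementedError` message interpolates `role`, which may not yet be bound on the first turn; the sketch above uses `from_role` in its message instead.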
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None

+    roles: Optional[Dict[str, List[str]]] = None
+

 class UserDefinedDPOType(BaseModel):
     """User defined typing for DPO"""
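On the config-schema side, `roles` validates as an optional mapping from direction to a list of role names. A minimal pydantic sketch of just that field; `SFTDatasetRolesOnly` is a trimmed stand-in for the real `SFTDataset`, whose other fields are elided:

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class SFTDatasetRolesOnly(BaseModel):
    """Trimmed stand-in for SFTDataset, showing only the new field."""

    roles: Optional[Dict[str, List[str]]] = None


cfg = SFTDatasetRolesOnly(roles={"input": ["tool"], "output": ["gpt"]})
assert cfg.roles["input"] == ["tool"]
```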
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
     )


+@pytest.fixture(name="multi_role_dataset")
+def fixture_multi_role_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {
+                        "from": "system",
+                        "value": "use get_weather(city) to get the weather for a city",
+                    },
+                    {
+                        "from": "human",
+                        "value": "hello, what's the weather in New York?",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "let me get that for you",
+                    },
+                    {
+                        "from": "tool",
+                        "value": "get_weather(New York)",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "the weather in New York is 70 degrees and sunny",
+                    },
+                ]
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -196,3 +228,39 @@ class TestSharegpt:
             32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
         ]
         # fmt: on
+
+    def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, multi_role_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            1,  # bos
+            32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
+            32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
+            32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
+            32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
+        ]
+        # fmt: on
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            -100,  # bos
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
+            -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
+        ]
+        # fmt: on