File size: 5,161 Bytes
e6b57de
 
5159d00
b7d8a7d
 
37293dc
e6b57de
5159d00
553a86b
 
2bc1a5b
cc5d31e
 
 
 
 
 
 
5159d00
48434be
cc5d31e
 
 
 
5159d00
 
48434be
5159d00
 
 
 
 
 
 
31b9e0c
5159d00
 
2bc1a5b
48434be
31b9e0c
5159d00
 
 
e50a64e
 
553a86b
3a38271
 
b7d8a7d
 
cc5d31e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7d8a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Module for tokenization utilities"""

import logging
import re
from typing import Dict, List

from termcolor import colored

# Module-level logger; the label-check helpers below emit their output through the "axolotl" logger.
LOG = logging.getLogger("axolotl")


def check_dataset_labels(
    dataset,
    tokenizer,
    num_examples=5,
    text_only=False,
    rl_mode=False,
):
    """Log a colorized token/label dump for the first few examples of a dataset.

    Args:
        dataset: Indexable collection of examples; ``dataset[idx]`` is passed
            to the per-example checker.
        tokenizer: Tokenizer forwarded unchanged to the per-example checker.
        num_examples: Maximum number of leading examples to inspect.
        text_only: Forwarded to the per-example checker; when true the token
            ids are omitted from the output.
        rl_mode: When true, examples are checked with
            ``check_rl_example_labels``; otherwise with
            ``check_example_labels``.
    """
    # Only the leading examples are inspected (the original author notes the
    # dataset is expected to be shuffled already). Clamp to the dataset size so
    # a dataset smaller than ``num_examples`` does not raise IndexError.
    try:
        limit = min(num_examples, len(dataset))
    except TypeError:
        # Dataset does not expose a length; fall back to the requested count.
        limit = num_examples

    for idx in range(limit):
        if rl_mode:
            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
        else:
            check_example_labels(dataset[idx], tokenizer, text_only=text_only)


def check_example_labels(example, tokenizer, text_only=False):
    """Log one tokenized example with each token colored by its label.

    Tokens are decoded one id at a time with ``tokenizer.decode`` and colored
    red when the label is ``-100`` (the ignore value), yellow when the label
    is ``0``, and green otherwise.

    Args:
        example: Mapping providing ``"input_ids"`` and ``"labels"``; the two
            sequences are walked in lockstep, so the shorter one bounds the
            output.
        tokenizer: Object whose ``decode`` accepts a single token id.
        text_only: When false, each token is followed by a white
            ``(label_id, input_id)`` annotation and the logged tokens are
            space-separated; when true, only the token text is logged, with no
            separator.

    Returns:
        The colored tokens joined with a single space. Note that the returned
        string is always space-joined, even when ``text_only`` makes the
        logged line use no separator.
    """
    input_ids = example["input_ids"]
    labels = example["labels"]

    colored_tokens = []
    for input_id, label_id in zip(input_ids, labels):
        decoded_input_token = tokenizer.decode(input_id)

        # Choose the color based on whether the label has the ignore value or not
        if label_id == -100:
            color = "red"
        elif label_id == 0:
            color = "yellow"
        else:
            color = "green"

        colored_token = colored(decoded_input_token, color)
        if not text_only:
            colored_token += colored(f"({label_id}, {input_id})", "white")
        colored_tokens.append(colored_token)

    delimiter = "" if text_only else " "
    LOG.info(delimiter.join(colored_tokens))
    LOG.info("\n\n\n")

    return " ".join(colored_tokens)


def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
    """Return the decoded token in ``color``, optionally followed by its id in white."""
    rendered = colored(decoded_token, color)
    if text_only:
        return rendered

    # Append the raw token value in parentheses so ids can be cross-checked.
    id_suffix = colored(f"({encoded_token})", "white")
    return f"{rendered}{id_suffix}"


def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
    """Encode ``tokens`` with the tokenizer and return one colored string per token id."""
    rendered = []
    for token_id in tokenizer.encode(tokens):
        # Decode each id individually so every token gets its own colored piece.
        piece = tokenizer.decode(token_id)
        rendered.append(color_token_for_rl_debug(piece, token_id, color, text_only))
    return rendered


def check_rl_example_labels(example, tokenizer, text_only=False):
    """Log the prompt, chosen and rejected fields of one preference example.

    Each field is run through ``process_tokens_for_rl_debug`` (which encodes
    it with the tokenizer) and colored yellow, green and red respectively.

    Returns:
        The colored prompt tokens joined with ``""`` when ``text_only`` is
        true, otherwise with a single space.
    """
    prompt_field = example["prompt"]
    chosen_field = example["chosen"]
    rejected_field = example["rejected"]

    # Color every segment according to its role in the preference pair.
    prompt_pieces = process_tokens_for_rl_debug(
        prompt_field, "yellow", tokenizer, text_only
    )
    chosen_pieces = process_tokens_for_rl_debug(
        chosen_field, "green", tokenizer, text_only
    )
    rejected_pieces = process_tokens_for_rl_debug(
        rejected_field, "red", tokenizer, text_only
    )

    # No separator in text-only mode; a space keeps the id annotations readable otherwise.
    if text_only:
        delimiter = ""
    else:
        delimiter = " "

    prompt_line = delimiter.join(prompt_pieces)
    chosen_line = delimiter.join(chosen_pieces)
    rejected_line = delimiter.join(rejected_pieces)

    LOG.info(f"INPUT PROMPT: {prompt_line}\n\n")
    LOG.info(f"CHOSEN RESPONSE: {chosen_line}\n\n")
    LOG.info(f"REJECTED RESPONSE: {rejected_line}\n\n\n")

    return prompt_line


# Role markers that can appear inline in the "chat" string, each written as "<ROLE>: ".
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
# Maps the role markers (plus SYSTEM, which is handled separately from the
# chat string) to the ShareGPT "from" values.
GLAIVE_TO_SHAREGPT_ROLE = {
    "SYSTEM": "system",
    "USER": "human",
    "ASSISTANT": "gpt",
    "FUNCTION RESPONSE": "tool",
}

# Matches "<ROLE>: " and captures ROLE, so that ``split`` keeps the role names
# in its output, interleaved with the message texts.
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")


def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
    """
    Converts a row holding a role-prefixed chat transcript to a list of messages in ShareGPT format.
    Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.

    The transcript in ``row["chat"]`` is split on the ``"<ROLE>: "`` markers
    listed in ``GLAIVE_ROLES``. Every marker produces exactly one message whose
    value is the stripped text up to the next marker; a marker followed by no
    text yields a message with an empty value, so roles and values can never
    get out of step.

    Args:
        row: Mapping with a required ``"chat"`` string and an optional
            ``"system"`` string. A leading ``"SYSTEM: "`` on the system string
            is removed.

    Returns:
        A list of ``{"from": ..., "value": ...}`` dicts. When a non-empty
        system prompt is present it comes first, with ``"from"`` set to
        ``"system"``.

    Raises:
        KeyError: if ``"chat"`` is missing, or if non-whitespace text precedes
            the first role marker (such text has no speaker).
    """

    system_prompt = row.get("system")
    if system_prompt:
        system_prompt = system_prompt.removeprefix("SYSTEM: ")

    # With a single capture group, split() returns
    # [text_before_first_marker, role, text, role, text, ...].
    preamble, *segments = GLAIVE_MSG_REGEX.split(row["chat"])

    leading_text = preamble.strip()
    if leading_text and segments:
        # Text ahead of the first marker cannot be attributed to a role. This
        # input has always failed with a KeyError on the role lookup; keep that
        # exception type. Whitespace-only prefixes are simply ignored.
        raise KeyError(leading_text)

    chat_msg_dicts = [
        {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value.strip()}
        for role, value in zip(segments[::2], segments[1::2])
    ]

    if system_prompt:
        chat_msg_dicts = [
            {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
        ] + chat_msg_dicts

    return chat_msg_dicts


def merge_consecutive_messages(messages):
    """
    Collapse runs of adjacent messages that share the same "from" into one message.
    The values of a run are concatenated directly, in order, with no separator.
    Handy for datasets where a sender emits several tool calls back to back.
    """

    result = []
    active_sender = None
    pending_value = ""

    for message in messages:
        sender = message["from"]

        # Same sender as the open run: extend it and move on.
        if active_sender == sender:
            pending_value += message["value"]
            continue

        # Sender changed: close the previous run (if one is open) and start anew.
        if active_sender is not None:
            result.append({"from": active_sender, "value": pending_value})
        active_sender = sender
        pending_value = message["value"]

    # Flush whatever run is still open after the last message.
    if active_sender is not None:
        result.append({"from": active_sender, "value": pending_value})

    return result