File size: 3,785 Bytes
5caedb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import pathlib
import random
import unittest
from unittest.mock import MagicMock

import pandas as pd
import pytest

from llm_studio.app_utils.default_datasets import (
    prepare_default_dataset_causal_language_modeling,
)
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.data_utils import load_train_valid_data


@pytest.fixture
def cfg_mock():
    """Minimal dataset config mock: paths, column names, and split size."""
    cfg = MagicMock()

    cfg.dataset.train_dataframe = "/path/to/train/data"
    cfg.dataset.validation_dataframe = "/path/to/validation/data"
    cfg.dataset.system_column = "None"
    cfg.dataset.prompt_column = "prompt"
    cfg.dataset.answer_column = "answer"
    cfg.dataset.validation_size = 0.2

    return cfg


@pytest.fixture
def read_dataframe_drop_missing_labels_mock(monkeypatch):
    """Patch the dataframe reader to return a fixed 100-row prompt/answer frame."""
    frame = pd.DataFrame(
        {
            "prompt": [f"Prompt{i}" for i in range(100)],
            "answer": [f"Answer{i}" for i in range(100)],
            "id": list(range(100)),
        }
    )
    reader_mock = MagicMock(return_value=frame)
    monkeypatch.setattr(
        "llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels",
        reader_mock,
    )
    return reader_mock


# Shuffled ids 0..99 partitioned into 13 interleaved groups; these stand in
# for conversation chains whose rows must never be split across train/val.
numbers = list(range(100))
random.shuffle(numbers)
groups = [numbers[start::13] for start in range(13)]


@pytest.fixture
def conversation_chain_ids_mock(monkeypatch):
    """Force ConversationChainHandler to expose the precomputed id groups."""

    def fake_init(self, *args, **kwargs):
        # Bypass the real constructor entirely; only the chain ids matter here.
        self.conversation_chain_ids = groups

    with unittest.mock.patch.object(
        ConversationChainHandler, "__init__", new=fake_init
    ):
        yield


def test_get_data_custom_validation_strategy(
    cfg_mock, read_dataframe_drop_missing_labels_mock
):
    """Custom strategy loads train and validation frames from the configured paths.

    The mocked reader returns a 100-row frame for each path, so both splits
    must come back with all 100 rows.
    """
    cfg_mock.dataset.validation_strategy = "custom"
    train_df, val_df = load_train_valid_data(cfg_mock)
    # The original `assert len(train_df), len(val_df) == 100` parsed as
    # `assert <expr>, <message>` — the `== 100` check was never executed.
    assert len(train_df) == 100
    assert len(val_df) == 100


def test_get_data_automatic_split(
    cfg_mock, read_dataframe_drop_missing_labels_mock, conversation_chain_ids_mock
):
    """Automatic split must partition all ids and keep conversation groups intact."""
    cfg_mock.dataset.validation_strategy = "automatic"

    train_df, val_df = load_train_valid_data(cfg_mock)
    ids_in_train = set(train_df["id"].tolist())
    ids_in_val = set(val_df["id"].tolist())

    # Train and validation together cover all 100 ids with no overlap.
    assert len(ids_in_train & ids_in_val) == 0
    assert len(ids_in_train) + len(ids_in_val) == 100

    # No conversation group may contribute rows to both splits.
    straddling_groups = [
        g
        for g in groups
        if not ids_in_train.isdisjoint(g) and not ids_in_val.isdisjoint(g)
    ]
    assert len(straddling_groups) == 0


def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
    """End-to-end automatic split on the downloaded OASST parquet dataset.

    For a range of validation sizes, verifies that no parent/child rows of a
    conversation end up on opposite sides of the split and that the realized
    validation fraction matches the requested one.
    """
    prepare_default_dataset_causal_language_modeling(tmp_path)
    assert len(os.listdir(tmp_path)) > 0, tmp_path

    cfg_mock = MagicMock()
    for file in os.listdir(tmp_path):
        if not file.endswith(".pq"):
            continue

        cfg_mock.dataset.train_dataframe = os.path.join(tmp_path, file)
        cfg_mock.dataset.system_column = "None"
        cfg_mock.dataset.prompt_column = ("instruction",)
        cfg_mock.dataset.answer_column = "output"
        cfg_mock.dataset.parent_id_column = "parent_id"
        cfg_mock.dataset.id_column = "id"
        cfg_mock.dataset.prompt_column_separator = "\n\n"
        cfg_mock.dataset.validation_strategy = "automatic"

        for validation_size in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]:
            cfg_mock.dataset.validation_size = validation_size
            train_df, val_df = load_train_valid_data(cfg_mock)

            # A train row's parent must never live in the validation split,
            # and vice versa — conversations stay within one split.
            train_parents = set(train_df["parent_id"].dropna().values)
            val_parents = set(val_df["parent_id"].dropna().values)
            assert train_parents.isdisjoint(set(val_df["id"].dropna().values))
            assert val_parents.isdisjoint(set(train_df["id"].dropna().values))

            # Realized validation fraction tracks the requested size (5% rel. tol).
            realized = len(val_df) / (len(train_df) + len(val_df))
            assert realized == pytest.approx(validation_size, 0.05)