import os
import pathlib
import random
import unittest
from unittest.mock import MagicMock
import pandas as pd
import pytest
from llm_studio.app_utils.default_datasets import (
prepare_default_dataset_causal_language_modeling,
)
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.data_utils import load_train_valid_data
@pytest.fixture
def cfg_mock():
    """Config mock carrying the dataset paths/columns the split tests read."""
    cfg = MagicMock()
    dataset = cfg.dataset
    dataset.train_dataframe = "/path/to/train/data"
    dataset.validation_dataframe = "/path/to/validation/data"
    dataset.system_column = "None"
    dataset.prompt_column = "prompt"
    dataset.answer_column = "answer"
    dataset.validation_size = 0.2
    return cfg
@pytest.fixture
def read_dataframe_drop_missing_labels_mock(monkeypatch):
    """Patch the frame reader to always return a fixed 100-row prompt/answer frame."""
    frame = pd.DataFrame(
        {
            "prompt": [f"Prompt{i}" for i in range(100)],
            "answer": [f"Answer{i}" for i in range(100)],
            "id": list(range(100)),
        }
    )
    reader_mock = MagicMock(return_value=frame)
    monkeypatch.setattr(
        "llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels",
        reader_mock,
    )
    return reader_mock
# Module-level fixture data: a random permutation of ids 0..99, partitioned
# into 13 interleaved "conversation" groups used by the automatic-split tests.
numbers = list(range(100))
random.shuffle(numbers)
groups = [numbers[offset::13] for offset in range(13)]
@pytest.fixture
def conversation_chain_ids_mock(monkeypatch):
    """Stub ConversationChainHandler.__init__ to expose the precomputed groups."""

    def fake_init(self, *args, **kwargs):
        # Skip real chain construction; install the deterministic grouping.
        self.conversation_chain_ids = groups

    patcher = unittest.mock.patch.object(
        ConversationChainHandler, "__init__", new=fake_init
    )
    with patcher:
        yield
def test_get_data_custom_validation_strategy(
    cfg_mock, read_dataframe_drop_missing_labels_mock
):
    """With the "custom" strategy, train and validation frames are read as-is.

    The patched reader returns the same fixed 100-row frame for both paths,
    so each loaded frame must contain all 100 rows.
    """
    cfg_mock.dataset.validation_strategy = "custom"
    train_df, val_df = load_train_valid_data(cfg_mock)
    # Bug fix: the original `assert len(train_df), len(val_df) == 100` parsed
    # as `assert len(train_df)` with `(len(val_df) == 100)` as the assert
    # *message*, so the size checks never actually ran.
    assert len(train_df) == 100
    assert len(val_df) == 100
def test_get_data_automatic_split(
    cfg_mock, read_dataframe_drop_missing_labels_mock, conversation_chain_ids_mock
):
    """Automatic split must partition all ids without splitting any group."""
    cfg_mock.dataset.validation_strategy = "automatic"
    train_df, val_df = load_train_valid_data(cfg_mock)

    train_ids = set(train_df["id"].tolist())
    val_ids = set(val_df["id"].tolist())
    # The two sides form a disjoint partition of all 100 ids.
    assert not (train_ids & val_ids)
    assert len(train_ids) + len(val_ids) == 100
    # No conversation group may straddle the train/validation boundary.
    straddling = [
        group
        for group in groups
        if not train_ids.isdisjoint(group) and not val_ids.isdisjoint(group)
    ]
    assert not straddling
def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
    """End-to-end automatic split on the prepared OASST default dataset.

    Checks that parent/child conversation links never cross the split boundary
    and that the realized validation fraction tracks each requested size.
    """
    prepare_default_dataset_causal_language_modeling(tmp_path)
    assert len(os.listdir(tmp_path)) > 0, tmp_path

    cfg_mock = MagicMock()
    # Point the config at the parquet file(s) produced above (last one wins,
    # matching the original scan order).
    for name in [f for f in os.listdir(tmp_path) if f.endswith(".pq")]:
        cfg_mock.dataset.train_dataframe = os.path.join(tmp_path, name)
    cfg_mock.dataset.system_column = "None"
    cfg_mock.dataset.prompt_column = ("instruction",)
    cfg_mock.dataset.answer_column = "output"
    cfg_mock.dataset.parent_id_column = "parent_id"
    cfg_mock.dataset.id_column = "id"
    cfg_mock.dataset.prompt_column_separator = "\n\n"
    cfg_mock.dataset.validation_strategy = "automatic"

    for validation_size in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]:
        cfg_mock.dataset.validation_size = validation_size
        train_df, val_df = load_train_valid_data(cfg_mock)
        # No train row references a validation row as its parent, and vice
        # versa: conversations stay entirely on one side of the split.
        assert set(train_df["parent_id"].dropna().values).isdisjoint(
            set(val_df["id"].dropna().values)
        )
        assert set(val_df["parent_id"].dropna().values).isdisjoint(
            set(train_df["id"].dropna().values)
        )
        # Realized validation fraction within 5% relative tolerance of target.
        realized = len(val_df) / (len(train_df) + len(val_df))
        assert realized == pytest.approx(validation_size, 0.05)