Spaces:
Sleeping
Sleeping
import os | |
import pathlib | |
import random | |
import unittest | |
from unittest.mock import MagicMock | |
import pandas as pd | |
import pytest | |
from llm_studio.app_utils.default_datasets import ( | |
prepare_default_dataset_causal_language_modeling, | |
) | |
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler | |
from llm_studio.src.utils.data_utils import load_train_valid_data | |
@pytest.fixture
def cfg_mock():
    """Return a ``MagicMock`` config carrying the dataset fields the split tests use.

    Only the attributes assigned below are fixed; any other attribute the code
    under test reads resolves to an auto-created ``MagicMock``. Each test sets
    ``cfg.dataset.validation_strategy`` itself.
    """
    cfg = MagicMock()
    cfg.dataset.train_dataframe = "/path/to/train/data"
    cfg.dataset.validation_dataframe = "/path/to/validation/data"
    # The *string* "None" (not the ``None`` object) marks "no system column";
    # presumably this mirrors the project's config convention — confirm.
    cfg.dataset.system_column = "None"
    cfg.dataset.prompt_column = "prompt"
    cfg.dataset.answer_column = "answer"
    cfg.dataset.validation_size = 0.2
    return cfg
@pytest.fixture
def read_dataframe_drop_missing_labels_mock(monkeypatch):
    """Patch ``read_dataframe_drop_missing_labels`` to return a fixed 100-row frame.

    The frame has ``prompt``, ``answer`` and ``id`` columns, with ids 0..99.
    The name is patched inside ``llm_studio.src.utils.data_utils``, the module
    ``load_train_valid_data`` is imported from, so lookups made there hit the
    mock. Every call returns the *same* DataFrame object, whatever path is
    passed.

    Returns:
        The ``MagicMock`` installed in place of the reader, so tests can inspect
        its calls.
    """
    data = {
        "prompt": [f"Prompt{i}" for i in range(100)],
        "answer": [f"Answer{i}" for i in range(100)],
        "id": list(range(100)),
    }
    df = pd.DataFrame(data)
    mock = MagicMock(return_value=df)
    # monkeypatch restores the real function at fixture teardown.
    monkeypatch.setattr(
        "llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels", mock
    )
    return mock
# Shuffle the ids 0..99, then deal them round-robin into 13 groups: group k
# holds every 13th shuffled id, starting at shuffled position k. The
# ConversationChainHandler patch uses these groups as its
# ``conversation_chain_ids``.
_NUM_GROUPS = 13
numbers = list(range(100))
random.shuffle(numbers)
groups = [numbers[offset::_NUM_GROUPS] for offset in range(_NUM_GROUPS)]
@pytest.fixture
def conversation_chain_ids_mock(monkeypatch):
    """Make ``ConversationChainHandler`` report the module-level ``groups``.

    ``__init__`` is replaced for the duration of the test. Any handler built
    while the fixture is active therefore skips the real constructor and
    exposes ``groups`` as its ``conversation_chain_ids``. The handler is
    presumably built inside ``load_train_valid_data`` for the automatic split;
    confirm. ``monkeypatch`` restores the original ``__init__`` at teardown,
    matching how the sibling reader fixture patches.
    """

    def mocked_init(self, *args, **kwargs):
        # Ignore every constructor argument and use the precomputed partition.
        self.conversation_chain_ids = groups

    monkeypatch.setattr(ConversationChainHandler, "__init__", mocked_init)
def test_get_data_custom_validation_strategy(
    cfg_mock, read_dataframe_drop_missing_labels_mock
):
    """With the "custom" strategy, both returned frames have the mocked 100 rows.

    The reader is mocked to return the same 100-row frame for any path. The
    train and validation frames are presumably both read through it; confirm
    against ``load_train_valid_data``.
    """
    cfg_mock.dataset.validation_strategy = "custom"

    train_df, val_df = load_train_valid_data(cfg_mock)

    # Two separate comparisons. Writing `assert len(train_df), len(val_df) == 100`
    # would only check that train_df is non-empty; the comparison after the
    # comma becomes the failure message.
    assert len(train_df) == 100
    assert len(val_df) == 100
def test_get_data_automatic_split(
    cfg_mock, read_dataframe_drop_missing_labels_mock, conversation_chain_ids_mock
):
    """The automatic split must partition the ids without splitting any group.

    Checks three things:
    - the train and validation ids are disjoint;
    - together they account for all 100 mocked rows;
    - no mocked conversation group has ids on both sides.
    """
    cfg_mock.dataset.validation_strategy = "automatic"
    train_df, val_df = load_train_valid_data(cfg_mock)

    train_id_set = set(train_df["id"].tolist())
    val_id_set = set(val_df["id"].tolist())

    assert train_id_set.isdisjoint(val_id_set)
    assert len(train_id_set) + len(val_id_set) == 100

    # A group "straddles" the split when it overlaps both id sets.
    straddling = [
        group
        for group in groups
        if not (train_id_set.isdisjoint(group) or val_id_set.isdisjoint(group))
    ]
    assert not straddling
def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
    """Run the automatic split on the prepared default causal-LM dataset.

    For each validation size in the sweep, checks two things:
    - no row in one split is the parent (``parent_id``) of a row (``id``) in
      the other split;
    - the validation fraction is within 5% (relative) of the requested size.

    The preparation step presumably downloads the OASST data, so this test
    likely needs network access; confirm.
    """
    prepare_default_dataset_causal_language_modeling(tmp_path)
    assert len(os.listdir(tmp_path)) > 0, tmp_path

    cfg = MagicMock()
    parquet_files = [name for name in os.listdir(tmp_path) if name.endswith(".pq")]
    if parquet_files:
        # If several parquet files exist, the last one in listdir order is used.
        cfg.dataset.train_dataframe = os.path.join(tmp_path, parquet_files[-1])

    dataset = cfg.dataset
    dataset.system_column = "None"
    # Presumably a tuple of prompt columns, joined by the separator below.
    dataset.prompt_column = ("instruction",)
    dataset.answer_column = "output"
    dataset.parent_id_column = "parent_id"
    dataset.id_column = "id"
    dataset.prompt_column_separator = "\n\n"
    dataset.validation_strategy = "automatic"

    for validation_size in (0.05, 0.1, 0.2, 0.3, 0.4, 0.5):
        dataset.validation_size = validation_size
        train_df, val_df = load_train_valid_data(cfg)

        train_parents = set(train_df["parent_id"].dropna().values)
        val_parents = set(val_df["parent_id"].dropna().values)
        train_ids = set(train_df["id"].dropna().values)
        val_ids = set(val_df["id"].dropna().values)

        # No conversation edge may cross the split in either direction.
        assert train_parents.isdisjoint(val_ids)
        assert val_parents.isdisjoint(train_ids)

        val_fraction = len(val_df) / (len(train_df) + len(val_df))
        assert val_fraction == pytest.approx(validation_size, rel=0.05)