File size: 2,004 Bytes
5caedb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
from typing import Any, Dict

import numpy as np
import pandas as pd

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)


class CustomDataset(TextCausalLanguageModelingCustomDataset):
    """Regression dataset built on top of the causal language modeling dataset.

    Inputs are prepared by the parent dataset. Each sample additionally carries
    the float-valued target(s) read from ``cfg.dataset.answer_column`` under
    the ``"class_label"`` key. ``cfg.dataset.answer_column`` is treated as a
    collection of column names (it is iterated and measured with ``len``).
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Build the dataset and cache the regression targets as floats.

        Args:
            df: Input data; every column in ``cfg.dataset.answer_column`` is
                cast with ``astype(float)``.
            cfg: Experiment config; reads ``cfg.dataset.answer_column`` and
                ``cfg.dataset.parent_id_column``.
            mode: Dataset split, forwarded to the parent dataset.

        Raises:
            LLMDataException: If a parent ID column is configured, i.e.
                ``cfg.dataset.parent_id_column`` is anything other than the
                string sentinel ``"None"``.
        """
        # Reject the unsupported configuration before running the parent's
        # initialization, rather than after all of its work has been done.
        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for regression datasets."
            )

        super().__init__(df=df, cfg=cfg, mode=mode)
        # One row of float targets per sample, looked up by __getitem__.
        self.answers_float = df[cfg.dataset.answer_column].astype(float).values

    def __getitem__(self, idx: int) -> Dict:
        """Return the parent's sample for ``idx`` plus its regression target(s).

        The targets for row ``idx`` are stored under ``"class_label"``.
        """
        sample = super().__getitem__(idx)
        sample["class_label"] = self.answers_float[idx]
        return sample

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        """Render numeric predictions as text, then defer to the parent.

        For each row, the first ``len(cfg.dataset.answer_column)`` prediction
        values are rounded to 3 decimals and joined with commas into
        ``output["predicted_text"]``.

        Args:
            cfg: Experiment config; reads ``cfg.dataset.answer_column``.
            df: Data frame forwarded unchanged to the parent implementation.
            output: Model outputs. ``output["predictions"]`` is indexed as a
                2-D ``[row, column]`` array and converted via ``.float()`` and
                ``.cpu().numpy()`` (presumably a torch tensor — confirm).

        Returns:
            The parent ``postprocess_output`` result for the updated ``output``.
        """
        # Stored back into ``output`` so the parent also sees float predictions.
        output["predictions"] = output["predictions"].float()
        # Convert to a NumPy array once instead of once per answer column.
        predictions = output["predictions"].cpu().numpy()

        n_answers = len(cfg.dataset.answer_column)
        columns = [
            np.round(predictions[:, col], 3).astype(str) for col in range(n_answers)
        ]
        # Transpose the per-column strings into one comma-joined string per row.
        output["predicted_text"] = [",".join(row) for row in zip(*columns)]
        return super().postprocess_output(cfg, df, output)

    def clean_output(self, output, cfg):
        """Return ``output`` unchanged.

        NOTE(review): presumably overrides the parent's text clean-up, which
        would not apply to numeric predictions — confirm against the parent.
        """
        return output

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Validate that ``df`` holds complete regression target columns.

        Args:
            df: Data frame to validate.
            cfg: Experiment config; reads ``cfg.dataset.answer_column``.
            mode: Split name, used only in error messages.

        Raises:
            AssertionError: If an answer column is absent from ``df`` or
                contains missing values.
        """
        # Explicit raises rather than ``assert`` statements so validation is
        # not stripped under ``python -O``; the exception type is unchanged.
        for answer_col in cfg.dataset.answer_column:
            if answer_col not in df.columns:
                raise AssertionError(
                    f"Answer column {answer_col} not found in the "
                    f"{mode} DataFrame."
                )
            if df.shape[0] != df[answer_col].dropna().shape[0]:
                raise AssertionError(
                    f"The {mode} DataFrame"
                    f" column {answer_col}"
                    " contains missing values."
                )