llm-studio / llm_studio /src /datasets /text_causal_classification_ds.py
qinfeng722's picture
Upload 322 files
5caedb4 verified
import logging
from typing import Any, Dict
import numpy as np
import pandas as pd
import torch
from llm_studio.src.datasets.text_causal_language_modeling_ds import (
CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException
logger = logging.getLogger(__name__)
class CustomDataset(TextCausalLanguageModelingCustomDataset):
def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
super().__init__(df=df, cfg=cfg, mode=mode)
check_for_non_int_answers(cfg, df)
self.answers_int = df[cfg.dataset.answer_column].astype(int).values
max_value = np.max(self.answers_int)
min_value = np.min(self.answers_int)
if 1 < cfg.dataset.num_classes <= max_value:
raise LLMDataException(
"Number of classes is smaller than max label "
f"{max_value}. Please increase the setting accordingly."
)
elif cfg.dataset.num_classes == 1 and max_value > 1:
raise LLMDataException(
"For binary classification, max label should be 1 but is "
f"{max_value}."
)
if min_value < 0:
raise LLMDataException(
"Labels should be non-negative but min label is " f"{min_value}."
)
if min_value != 0 or max_value != np.unique(self.answers_int).size - 1:
logger.warning(
"Labels should start at 0 and be continuous but are "
f"{sorted(np.unique(self.answers_int))}."
)
if cfg.dataset.parent_id_column != "None":
raise LLMDataException(
"Parent ID column is not supported for classification datasets."
)
def __getitem__(self, idx: int) -> Dict:
sample = super().__getitem__(idx)
sample["class_label"] = self.answers_int[idx]
return sample
def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
output["logits"] = output["logits"].float()
if cfg.training.loss_function == "CrossEntropyLoss":
output["probabilities"] = torch.softmax(output["logits"], dim=-1)
else:
output["probabilities"] = torch.sigmoid(output["logits"])
if len(cfg.dataset.answer_column) == 1:
if cfg.dataset.num_classes == 1:
output["predictions"] = (output["probabilities"] > 0.5).long()
else:
output["predictions"] = output["probabilities"].argmax(
dim=-1, keepdim=True
)
else:
output["predictions"] = (output["probabilities"] > 0.5).long()
preds = []
for col in np.arange(output["probabilities"].shape[1]):
preds.append(
np.round(output["probabilities"][:, col].cpu().numpy(), 3).astype(str)
)
preds = [",".join(pred) for pred in zip(*preds)]
output["predicted_text"] = preds
return super().postprocess_output(cfg, df, output)
def clean_output(self, output, cfg):
return output
@classmethod
def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
for answer_col in cfg.dataset.answer_column:
assert answer_col in df.columns, (
f"Answer column {answer_col} not found in the " f"{mode} DataFrame."
)
assert df.shape[0] == df[answer_col].dropna().shape[0], (
f"The {mode} DataFrame"
f" column {answer_col}"
" contains missing values."
)
check_for_non_int_answers(cfg, df)
def check_for_non_int_answers(cfg, df):
answers_non_int: list = []
for column in cfg.dataset.answer_column:
answers_non_int.extend(
x for x in df[column].values if not is_castable_to_int(x)
)
if len(answers_non_int) > 0:
raise LLMDataException(
f"Column {cfg.dataset.answer_column} contains non int items. "
f"Sample values: {answers_non_int[:5]}."
)
def is_castable_to_int(s):
try:
int(s)
return True
except ValueError:
return False