import logging
from typing import Any, Dict

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)
class CustomDataset(TextCausalLanguageModelingCustomDataset):
    """Dataset for causal classification: reads integer labels from the answer column(s)."""

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        super().__init__(df=df, cfg=cfg, mode=mode)
        check_for_non_int_answers(cfg, df)
        self.answers_int = df[cfg.dataset.answer_column].astype(int).values
        max_value = np.max(self.answers_int)
        min_value = np.min(self.answers_int)

        # Labels must fit into the configured number of classes.
        if 1 < cfg.dataset.num_classes <= max_value:
            raise LLMDataException(
                "Number of classes is smaller than max label "
                f"{max_value}. Please increase the setting accordingly."
            )
        elif cfg.dataset.num_classes == 1 and max_value > 1:
            raise LLMDataException(
                f"For binary classification, max label should be 1 but is {max_value}."
            )
        if min_value < 0:
            raise LLMDataException(
                f"Labels should be non-negative but min label is {min_value}."
            )
        # Labels are expected to form a dense range 0..num_classes-1.
        if min_value != 0 or max_value != np.unique(self.answers_int).size - 1:
            logger.warning(
                "Labels should start at 0 and be continuous but are "
                f"{sorted(np.unique(self.answers_int))}."
            )

        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for classification datasets."
            )
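
    # For illustration (hypothetical label sets, not in the source): with
    # num_classes=3, labels {0, 1, 2} pass all checks above; {1, 2, 3} raises
    # (max label >= num_classes); {0, 2} only triggers the continuity warning;
    # {-1, 0, 1} raises on the negative label.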

    def __getitem__(self, idx: int) -> Dict:
        sample = super().__getitem__(idx)
        # Attach the integer label so the classification head/loss can use it.
        sample["class_label"] = self.answers_int[idx]
        return sample

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        output["logits"] = output["logits"].float()
        if cfg.training.loss_function == "CrossEntropyLoss":
            output["probabilities"] = torch.softmax(output["logits"], dim=-1)
        else:
            output["probabilities"] = torch.sigmoid(output["logits"])

        if len(cfg.dataset.answer_column) == 1:
            if cfg.dataset.num_classes == 1:
                # Binary classification: threshold the single sigmoid output.
                output["predictions"] = (output["probabilities"] > 0.5).long()
            else:
                # Multiclass: pick the class with the highest probability.
                output["predictions"] = output["probabilities"].argmax(
                    dim=-1, keepdim=True
                )
        else:
            # Multilabel: threshold each class probability independently.
            output["predictions"] = (output["probabilities"] > 0.5).long()

        # Serialize per-class probabilities as comma-joined strings, one per row.
        preds = []
        for col in np.arange(output["probabilities"].shape[1]):
            preds.append(
                np.round(output["probabilities"][:, col].cpu().numpy(), 3).astype(str)
            )
        preds = [",".join(pred) for pred in zip(*preds)]
        output["predicted_text"] = preds
        return super().postprocess_output(cfg, df, output)
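
    # Example (hypothetical values): for probabilities [[0.91, 0.07, 0.02]],
    # "predicted_text" becomes ["0.91,0.07,0.02"], i.e. one comma-joined string
    # of per-class probabilities (rounded to 3 decimals) per input row.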

    def clean_output(self, output, cfg):
        # Classification predictions need no text cleanup; pass through as-is.
        return output

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        for answer_col in cfg.dataset.answer_column:
            assert (
                answer_col in df.columns
            ), f"Answer column {answer_col} not found in the {mode} DataFrame."
            assert df.shape[0] == df[answer_col].dropna().shape[0], (
                f"The {mode} DataFrame column {answer_col} contains missing values."
            )
        check_for_non_int_answers(cfg, df)


def check_for_non_int_answers(cfg, df):
    answers_non_int: list = []
    for column in cfg.dataset.answer_column:
        answers_non_int.extend(
            x for x in df[column].values if not is_castable_to_int(x)
        )
    if len(answers_non_int) > 0:
        raise LLMDataException(
            f"Column {cfg.dataset.answer_column} contains non-integer items. "
            f"Sample values: {answers_non_int[:5]}."
        )
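
# Example (hypothetical data): a label column holding ["0", "1", "cat"] makes
# check_for_non_int_answers raise with sample value "cat"; ["0", "1", "2"] passes.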


def is_castable_to_int(s):
    try:
        int(s)
        return True
    except (TypeError, ValueError):
        # TypeError covers None; ValueError covers strings like "cat" and NaN.
        return False
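

# ---------------------------------------------------------------------------
# Minimal sketch (illustration only, not part of the original module): how the
# two prediction paths in postprocess_output behave on a toy batch. The logits
# below are made up for demonstration.
if __name__ == "__main__":
    toy_logits = torch.tensor([[2.0, -1.0, 0.5], [-0.3, 1.2, 0.1]])

    # Multiclass path (CrossEntropyLoss): softmax over classes, then argmax.
    probabilities = torch.softmax(toy_logits, dim=-1)
    print(probabilities.argmax(dim=-1).tolist())  # -> [0, 1]

    # Binary/multilabel path: per-logit sigmoid with a 0.5 threshold.
    probabilities = torch.sigmoid(toy_logits)
    print((probabilities > 0.5).long().tolist())  # -> [[1, 0, 1], [0, 1, 1]]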