patch_series / patch_series.py
bowdbeg's picture
implemented
6d01d6a
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""
import logging
from typing import List, Optional, Union
import datasets
import evaluate
import numpy as np
logger = logging.getLogger(__name__)
# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""
# TODO: Add description of the module here
_DESCRIPTION = """\
This new module is designed to solve this great ML task and is crafted with a lot of care.
"""
# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of predictions to score. Each predictions
should be a string with tokens separated by spaces.
references: list of reference for each prediction. Each
reference should be a string with tokens separated by spaces.
Returns:
accuracy: description of the first score,
another_score: description of the second score,
Examples:
Examples should be written in doctest format, and should illustrate how
to use the function.
>>> my_new_module = evaluate.load("my_new_module")
>>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
>>> print(results)
{'accuracy': 1.0}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class patch_series(evaluate.Metric):
    """Evaluate time series patch-wise with the ``bowdbeg/matching_series`` metric.

    For every requested patch length, each prediction/reference series is cut into
    windows along the time axis (axis 1). ``bowdbeg/matching_series`` is computed on
    the resulting batch of patches, and the per-patch-length results are averaged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Loaded once and reused by every ``_compute`` call.
        self.matching_series_metric = evaluate.load("bowdbeg/matching_series")

    def _info(self):
        """Return the metadata ``evaluate`` uses to describe this module."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # One example is one series of shape (sequence_length, num_features) of
            # floats, not a scalar int as the template declared.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float64"))),
                    "references": datasets.Sequence(datasets.Sequence(datasets.Value("float64"))),
                }
            ),
            # TODO: replace the template placeholder URLs below.
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[dict]:
        """Compute the patch-wise metric directly from the given inputs.

        This override passes ``predictions``/``references`` straight to ``_compute``.
        It does not read data accumulated through ``add``/``add_batch``. Keyword
        arguments that are not input features (e.g. ``patch_length``, ``strides``,
        or options for ``bowdbeg/matching_series``) are forwarded to ``_compute``.

        Raises:
            ValueError: If ``predictions`` or ``references`` is not given.
        """
        if predictions is None or references is None:
            raise ValueError(
                "Both `predictions` and `references` must be passed to `compute`; "
                "data added with `add`/`add_batch` is not used by this metric."
            )
        all_kwargs = {"predictions": predictions, "references": references, **kwargs}
        feature_names = list(self._feature_names())
        missing_inputs = [k for k in feature_names if k not in all_kwargs]
        if missing_inputs:
            raise ValueError(
                f"Evaluation module inputs are missing: {missing_inputs}. All required inputs are {feature_names}"
            )
        inputs = {name: all_kwargs[name] for name in feature_names}
        compute_kwargs = {k: v for k, v in kwargs.items() if k not in feature_names}
        return self._compute(**inputs, **compute_kwargs)

    def _compute(
        self,
        predictions: Union[List, np.ndarray],
        references: Union[List, np.ndarray],
        patch_length: Union[int, List[int], None] = None,
        strides: Union[int, List[int], None] = None,
        **kwargs,
    ):
        """Compute the score for bowdbeg/matching_series for each patch length and take the mean.

        Args:
            predictions: Array-like of shape (batch_size, sequence_length, num_features).
            references: Array-like of shape (batch_size, sequence_length, num_features).
            patch_length: Patch length(s) to evaluate. A single int is treated as a
                one-element list. Defaults to ``[1]``.
            strides: Stride(s) between consecutive patches, one per patch length.
                Defaults to ``patch_length`` (non-overlapping patches).
            **kwargs: Forwarded unchanged to ``bowdbeg/matching_series``.

        Returns:
            A dict keyed like the first ``bowdbeg/matching_series`` result. Each scalar
            or array value is the mean over the patch lengths; other value types are
            dropped with a warning.

        Raises:
            ValueError: If the inputs are not 3-D; if the patch lengths are empty,
                differ in number from the strides, or are not positive; or if a patch
                length does not divide the sequence length of the predictions or the
                references.
            TypeError: If ``bowdbeg/matching_series`` does not return a dict.
        """
        patch_lengths = [1] if patch_length is None else self._as_list(patch_length)
        stride_list = list(patch_lengths) if strides is None else self._as_list(strides)

        if not patch_lengths:
            raise ValueError("The patch_length should be a list of integers.")
        if len(patch_lengths) != len(stride_list):
            raise ValueError("The patch_length and strides should have the same length.")
        if any(p < 1 for p in patch_lengths) or any(s < 1 for s in stride_list):
            raise ValueError("The patch_length and strides should contain positive integers.")

        predictions = np.asarray(predictions)
        references = np.asarray(references)
        # Check the rank first: the divisibility check below indexes shape[1].
        if predictions.ndim != 3:
            raise ValueError("Predictions should have shape (batch_size, sequence_length, num_features)")
        if references.ndim != 3:
            raise ValueError("References should have shape (batch_size, sequence_length, num_features)")
        # Raise when EITHER input is not divisible (the original `and` required both to fail).
        if any(predictions.shape[1] % p != 0 or references.shape[1] % p != 0 for p in patch_lengths):
            raise ValueError("The patch_length should divide the length of the predictions and references.")

        results = []
        for patch, stride in zip(patch_lengths, stride_list):
            res = self.matching_series_metric.compute(
                predictions=self._patch_batch(predictions, patch, stride),
                references=self._patch_batch(references, patch, stride),
                **kwargs,
            )
            if not isinstance(res, dict):
                raise TypeError(f"bowdbeg/matching_series returned {type(res)}, expected a dict.")
            results.append(res)
        return self._average_results(results)

    @staticmethod
    def _as_list(value) -> list:
        """Wrap a single integer in a list; otherwise copy the given sequence into a list."""
        if isinstance(value, (int, np.integer)):
            return [value]
        return list(value)

    def _patch_batch(self, series: np.ndarray, patch: int, stride: int) -> np.ndarray:
        """Cut a (batch, seq, feat) array into a (num_patches, patch, feat) batch of windows."""
        windows = self.get_patches(series, patch, stride, axis=1)  # (batch, n_windows, feat, patch)
        # Move the window-content axis next to the window index before flattening.
        # Reshaping (..., feat, patch) straight to (patch, feat) would interleave time and features.
        windows = np.moveaxis(windows, -1, 2)  # (batch, n_windows, patch, feat)
        return windows.reshape(-1, patch, series.shape[2])

    @staticmethod
    def _average_results(results: List[dict]) -> dict:
        """Average the numeric entries of the per-patch-length result dicts.

        Keys are taken from the first result. List/array values are summed
        element-wise (NumPy broadcasting) and scalar numbers are summed; both are then
        divided by the number of results. Any other value type is dropped with a warning.
        """
        count = len(results)
        averaged = {}
        for key, first in results[0].items():
            rest = [res[key] for res in results[1:]]
            if isinstance(first, (list, np.ndarray)):
                total = np.asarray(first)
                for value in rest:
                    total = total + np.asarray(value)
                averaged[key] = total / count
            elif isinstance(first, (int, float, np.number)):
                total = first
                for value in rest:
                    total += value
                averaged[key] = total / count
            else:
                logger.warning(f"Unsupported type for key {key}: {type(first)}")
        return averaged

    @staticmethod
    def get_patches(series: np.ndarray, patch_length: int, stride: int, axis: int = 0) -> np.ndarray:
        """Return every ``stride``-th window of length ``patch_length`` along ``axis``.

        The result follows ``numpy.lib.stride_tricks.sliding_window_view``'s layout.
        Dimension ``axis`` indexes the kept window start positions, and a new trailing
        dimension of size ``patch_length`` holds each window's contents. The result
        is a read-only view of ``series``.

        Raises:
            ValueError: If ``patch_length`` or ``stride`` is not positive, or ``axis``
                is out of range.
        """
        if patch_length < 1 or stride < 1:
            raise ValueError("patch_length and stride should be positive integers.")
        series = np.asarray(series)
        if not -series.ndim <= axis < series.ndim:
            raise ValueError(f"axis {axis} is out of bounds for an array of dimension {series.ndim}")
        # Normalize now: the window view has one more dimension, so a negative axis would
        # otherwise refer to a different dimension when indexing it below.
        axis = axis % series.ndim
        windows = np.lib.stride_tricks.sliding_window_view(series, window_shape=patch_length, axis=axis)
        # Apply the stride along the windowed axis (not along axis 0, as the original did).
        index = [slice(None)] * windows.ndim
        index[axis] = slice(None, None, stride)
        return windows[tuple(index)]