Natooz commited on
Commit
78003b9
·
unverified ·
1 Parent(s): 4fb9244

code style update

Browse files
Files changed (1) hide show
  1. ece.py +16 -7
ece.py CHANGED
@@ -17,7 +17,10 @@ from typing import Dict
17
  import evaluate
18
  import datasets
19
  from torch import from_numpy, amax
20
- from torchmetrics.functional.classification.calibration_error import binary_calibration_error, multiclass_calibration_error
 
 
 
21
  from numpy import ndarray
22
 
23
 
@@ -75,15 +78,21 @@ class ECE(evaluate.Metric):
75
  citation=_CITATION,
76
  inputs_description=_KWARGS_DESCRIPTION,
77
  # This defines the format of each prediction and reference
78
- features=datasets.Features({
79
- 'predictions': datasets.Sequence(datasets.Value('float32')),
80
- 'references': datasets.Value('int64'),
81
- }),
 
 
82
  # Homepage of the module for documentation
83
  homepage="https://huggingface.co/spaces/Natooz/ece",
84
  # Additional links to the codebase or references
85
- codebase_urls=["https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"],
86
- reference_urls=["https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"]
 
 
 
 
87
  )
88
 
89
  def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
 
17
  import evaluate
18
  import datasets
19
  from torch import from_numpy, amax
20
+ from torchmetrics.functional.classification.calibration_error import (
21
+ binary_calibration_error,
22
+ multiclass_calibration_error,
23
+ )
24
  from numpy import ndarray
25
 
26
 
 
78
  citation=_CITATION,
79
  inputs_description=_KWARGS_DESCRIPTION,
80
  # This defines the format of each prediction and reference
81
+ features=datasets.Features(
82
+ {
83
+ "predictions": datasets.Sequence(datasets.Value("float32")),
84
+ "references": datasets.Value("int64"),
85
+ }
86
+ ),
87
  # Homepage of the module for documentation
88
  homepage="https://huggingface.co/spaces/Natooz/ece",
89
  # Additional links to the codebase or references
90
+ codebase_urls=[
91
+ "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
92
+ ],
93
+ reference_urls=[
94
+ "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
95
+ ],
96
  )
97
 
98
  def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]: