implemented
- .gitignore +133 -0
- __main__.py +56 -0
- patch_series.py +99 -22
.gitignore
ADDED
@@ -0,0 +1,133 @@
+.vscode
+data/
+output/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
__main__.py
ADDED
@@ -0,0 +1,56 @@
+import json
+import logging
+import time
+from argparse import ArgumentParser
+
+import evaluate
+import numpy as np
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+parser = ArgumentParser(
+    description="Compute the matching series score between two time series stored in a numpy array"
+)
+parser.add_argument("predictions", type=str, help="Path to the numpy array containing the predictions")
+parser.add_argument("references", type=str, help="Path to the numpy array containing the references")
+parser.add_argument("--output", type=str, help="Path to the output file")
+parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
+parser.add_argument("--num_processes", type=int, help="Number of processes to use for the computation", default=1)
+parser.add_argument("--dtype", type=str, help="Data type to use for the computation", default="float32")
+parser.add_argument("--debug", action="store_true", help="Debug mode (truncate inputs to the first 1000 series)")
+args = parser.parse_args()
+
+if not args.predictions or not args.references:
+    raise ValueError("You must provide the path to the predictions and references numpy arrays")
+
+
+predictions = np.load(args.predictions).astype(args.dtype)
+references = np.load(args.references).astype(args.dtype)
+
+if args.debug:
+    predictions = predictions[:1000]
+    references = references[:1000]
+
+logger.info(f"predictions shape: {predictions.shape}")
+logger.info(f"references shape: {references.shape}")
+
+import patch_series
+
+s = time.time()
+metric = patch_series.patch_series()
+# metric = evaluate.load("patch_series.py")
+results = metric.compute(
+    predictions=predictions,
+    references=references,
+    batch_size=args.batch_size,
+    num_processes=args.num_processes,
+    return_each_features=True,
+    return_coverages=True,
+    dtype=args.dtype,
+)
+logger.info(f"Time taken: {time.time() - s}")
+
+print(json.dumps(results))
+if args.output:
+    with open(args.output, "w") as f:
+        json.dump(results, f)
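As a smoke test, the entry point can be driven with toy arrays shaped (batch, sequence_length, num_features). A minimal sketch; the file names and shapes below are illustrative, not part of the commit:

    import numpy as np

    # hypothetical toy inputs: 8 series of length 12 with 3 features each
    np.save("predictions.npy", np.random.rand(8, 12, 3).astype("float32"))
    np.save("references.npy", np.random.rand(8, 12, 3).astype("float32"))

    # then, from the repository root (the directory holding __main__.py):
    #   python . predictions.npy references.npy --output results.json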
patch_series.py
CHANGED
@@ -13,9 +13,14 @@
 # limitations under the License.
 """TODO: Add a description here."""
 
-import evaluate
+import logging
+from typing import List, Optional, Union
+
 import datasets
+import evaluate
+import numpy as np
 
+logger = logging.getLogger(__name__)
 
 # TODO: Add BibTeX citation
 _CITATION = """\
@@ -53,13 +58,13 @@ Examples:
     {'accuracy': 1.0}
 """
 
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class patch_series(evaluate.Metric):
-
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.matching_series_metric = evaluate.load("bowdbeg/matching_series")
 
     def _info(self):
         # TODO: Specifies the evaluate.EvaluationModuleInfo object
@@ -70,26 +75,98 @@ class patch_series(evaluate.Metric):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
-            features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
-            }),
+            features=datasets.Features(
+                {
+                    "predictions": datasets.Value("int64"),
+                    "references": datasets.Value("int64"),
+                }
+            ),
             # Homepage of the module for documentation
             homepage="http://module.homepage",
             # Additional links to the codebase or references
             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-            reference_urls=["http://path.to.reference.url/new_module"]
+            reference_urls=["http://path.to.reference.url/new_module"],
         )
 
-    def _download_and_prepare(self, dl_manager):
-        """Optional: download external resources useful to compute the scores"""
-        # TODO: Download external resources if needed
-        pass
-
-    def _compute(self, predictions, references):
-        """Returns the scores"""
-        # TODO: Compute the actual scores of the module instead of returning these fake results.
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+    def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[dict]:
+        """Validate the inputs and dispatch them directly to _compute."""
+        all_kwargs = {"predictions": predictions, "references": references, **kwargs}
+        if predictions is None and references is None:
+            missing_kwargs = {k: None for k in self._feature_names() if k not in all_kwargs}
+            all_kwargs.update(missing_kwargs)
+        else:
+            missing_inputs = [k for k in self._feature_names() if k not in all_kwargs]
+            if missing_inputs:
+                raise ValueError(
+                    f"Evaluation module inputs are missing: {missing_inputs}. All required inputs are {list(self._feature_names())}"
+                )
+        inputs = {input_name: all_kwargs[input_name] for input_name in self._feature_names()}
+        compute_kwargs = {k: kwargs[k] for k in kwargs if k not in self._feature_names()}
+        return self._compute(**inputs, **compute_kwargs)
+
+    def _compute(
+        self,
+        predictions: Union[List, np.ndarray],
+        references: Union[List, np.ndarray],
+        patch_length: List[int] = [1],
+        strides: Union[List[int], None] = None,
+        **kwargs,
+    ):
+        """Compute the bowdbeg/matching_series score for each patch length and take the mean."""
+        if strides is None:
+            strides = patch_length
+        assert len(patch_length) == len(strides), "The patch_length and strides should have the same length."
+        if len(patch_length) == 0:
+            raise ValueError("The patch_length should be a non-empty list of integers.")
+        predictions = np.array(predictions)
+        references = np.array(references)
+        if len(predictions.shape) != 3:
+            raise ValueError("Predictions should have shape (batch_size, sequence_length, num_features)")
+        if not all(predictions.shape[1] % p == 0 for p in patch_length) or not all(
+            references.shape[1] % p == 0 for p in patch_length
+        ):
+            raise ValueError("Each patch_length should divide the length of the predictions and references.")
+        res_sum: Union[None, dict] = None
+        orig_pred_shape = predictions.shape
+        orig_ref_shape = references.shape
+        for patch, stride in zip(patch_length, strides):
+            # create patched predictions and references
+            patched_predictions = self.get_patches(predictions, patch, stride, axis=1)
+            patched_references = self.get_patches(references, patch, stride, axis=1)
+            patched_predictions = patched_predictions.reshape(-1, patch, orig_pred_shape[2])
+            patched_references = patched_references.reshape(-1, patch, orig_ref_shape[2])
+
+            # compute the score for each patch
+            res = self.matching_series_metric.compute(
+                predictions=patched_predictions, references=patched_references, **kwargs
+            )
+            # sum the results
+            if res_sum is None:
+                res_sum = res
+            else:
+                assert isinstance(res_sum, dict)
+                assert isinstance(res, dict)
+                # iterate over a copy of the keys so unsupported entries can be deleted safely
+                for key in list(res_sum):
+                    if isinstance(res_sum[key], (list, np.ndarray)):
+                        res_sum[key] = np.array(res_sum[key]) + np.array(res[key])
+                    elif isinstance(res_sum[key], (float, int)):
+                        res_sum[key] += res[key]
+                    else:
+                        logger.warning(f"Unsupported type for key {key}: {type(res_sum[key])}")
+                        del res_sum[key]
+        # take the mean of the results
+        assert isinstance(res_sum, dict)
+        for key in res_sum:
+            if isinstance(res_sum[key], (list, np.ndarray)):
+                res_sum[key] = np.array(res_sum[key]) / len(patch_length)
+            else:
+                res_sum[key] /= len(patch_length)
+
+        return res_sum
+
+    @staticmethod
+    def get_patches(series: np.ndarray, patch_length: int, stride: int, axis=0):
+        # extract sliding windows along `axis`; numpy appends the window dimension last
+        o = np.lib.stride_tricks.sliding_window_view(series, window_shape=patch_length, axis=axis)
+        # move the window dimension next to `axis` so features stay last, then apply the
+        # stride along the windowed axis (striding axis 0 would skip whole batch entries)
+        o = np.moveaxis(o, -1, axis + 1)
+        index = [slice(None)] * o.ndim
+        index[axis] = slice(None, None, stride)
+        return o[tuple(index)]
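For intuition about the patching step, a small sketch (toy shapes, not part of the commit) of what get_patches produces and why _compute can reshape its output to (num_patches, patch_length, num_features):

    import numpy as np

    series = np.arange(2 * 6 * 3, dtype=np.float32).reshape(2, 6, 3)  # (batch, time, features)
    # sliding windows of length 2 along the time axis; numpy appends the window dim last
    w = np.lib.stride_tricks.sliding_window_view(series, window_shape=2, axis=1)
    print(w.shape)  # (2, 5, 3, 2)
    w = np.moveaxis(w, -1, 2)[:, ::2]  # window dim next to time, stride 2 -> (2, 3, 2, 3)
    patches = w.reshape(-1, 2, 3)      # (batch * num_windows, patch_length, features)
    print(patches.shape)               # (6, 2, 3)

Calling the metric end to end might then look like the following sketch, assuming the bowdbeg/matching_series module is reachable and accepts these keyword arguments, as the commit's own code implies:

    import numpy as np
    from patch_series import patch_series

    metric = patch_series()
    results = metric.compute(
        predictions=np.random.rand(4, 8, 3),
        references=np.random.rand(4, 8, 3),
        patch_length=[2, 4],  # patch sizes; strides default to the same values
    )
    print(results)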