Upload operators.py with huggingface_hub

operators.py CHANGED (+709 −162)
@@ -1,8 +1,8 @@
 import collections
 import importlib
-import inspect
 import uuid
 from abc import abstractmethod
+from collections import Counter
 from copy import deepcopy
 from dataclasses import field
 from itertools import zip_longest
@@ -19,7 +19,7 @@ from typing import (
 )

 from .artifact import Artifact, fetch_artifact
+from .dataclass import NonPositionalField
 from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
 from .operator import (
     MultiStream,
@@ -32,15 +32,14 @@ from .operator import (
     StreamInstanceOperator,
     StreamSource,
 )
-from .random_utils import
-from .stream import
+from .random_utils import get_random, nested_seed
+from .stream import Stream
 from .text_utils import nested_tuple_to_string
 from .utils import flatten_dict


 class FromIterables(StreamInitializerOperator):
-    """
-    Creates a MultiStream from iterables.
+    """Creates a MultiStream from iterables.

     Args:
         iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
@@ -70,35 +69,83 @@ class MapInstanceValues(StreamInstanceOperator):
     strict (bool): If True, the mapping is applied strictly. That means if a value
         does not exist in the mapper, it will raise a KeyError. If False, values
         that are not present in the mapper are kept as they are.
+    process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
+        is to be applied to their individual elements. If False, mapping is only applied to a field
+        containing a single value.
+
+    Examples:
+        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
+        replaces '1' with 'hi' and '2' with 'bye' in field 'a' in all instances of all streams:
+        instance {"a": "1", "b": 2} becomes {"a": "hi", "b": 2}.
+
+        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)
+        Assuming field 'a' is a list of values, potentially including "1"-s and "2"-s, this replaces
+        each such "1" with "hi" and each "2" with "bye" in all instances of all streams:
+        instance {"a": ["1", "2"], "b": 2} becomes {"a": ["hi", "bye"], "b": 2}.

+        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)
+        To ensure that all values of field 'a' are mapped in every instance, use strict=True.
+        Input instance {"a": "3", "b": 2} will raise an exception per the above call,
+        because "3" is not a key in the mapper of "a".
     """

     mappers: Dict[str, Dict[str, str]]
     strict: bool = True
-    use_query = False
+    use_query: bool = False
+    process_every_value: bool = False

     def verify(self):
         # make sure the mappers are valid
         for key, mapper in self.mappers.items():
-            assert isinstance(
+            assert isinstance(
+                mapper, dict
+            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
+            for k in mapper.keys():
+                assert isinstance(
+                    k, str
+                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         for key, mapper in self.mappers.items():
             value = dict_get(instance, key, use_dpath=self.use_query)
             if value is not None:
+                if (self.process_every_value is True) and (not isinstance(value, list)):
+                    raise ValueError(
+                        f"'process_every_value' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instance = {instance}"
+                    )
+                if isinstance(value, list):
+                    if self.process_every_value:
+                        for i, val in enumerate(value):
+                            val = str(val)  # make sure the value is a string
+                            if self.strict and (val not in mapper):
+                                raise KeyError(
+                                    f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
+                                )
+                            if val in mapper:
+                                # replace just that member of value (value is a list)
+                                value[i] = mapper[val]
+                        dict_set(instance, key, value, use_dpath=self.use_query)
+                    else:  # field is a list, and process_every_value == False
+                        if self.strict:  # whole lists can not be mapped by a string-to-something mapper
+                            raise KeyError(
+                                f"A whole list ({value}) in the instance can not be mapped by a field mapper."
+                            )
+                else:  # value is not a list, implying process_every_value == False
+                    value = str(value)  # make sure the value is a string
+                    if self.strict and (value not in mapper):
+                        raise KeyError(
+                            f"value '{value}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
+                        )
                     if value in mapper:
                         dict_set(instance, key, mapper[value], use_dpath=self.use_query)
+
         return instance
|
| 148 |
+
"""Flattens each instance in a stream, making nested dictionary entries into top-level entries.
|
|
|
|
| 149 |
|
| 150 |
Args:
|
| 151 |
parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
|
|
|
|
| 155 |
parent_key: str = ""
|
| 156 |
sep: str = "_"
|
| 157 |
|
| 158 |
+
def process(
|
| 159 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 160 |
+
) -> Dict[str, Any]:
|
| 161 |
return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
|
| 162 |
|
| 163 |
|
| 164 |
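
FlattenInstances delegates to flatten_dict from .utils; a sketch of the recursive key-joining semantics it implies (the library's helper may differ in details):

    def flatten_dict(d, parent_key="", sep="_"):
        # {"a": {"b": 1}} -> {"a_b": 1}
        items = {}
        for key, value in d.items():
            new_key = parent_key + sep + key if parent_key else key
            if isinstance(value, dict):
                items.update(flatten_dict(value, new_key, sep=sep))
            else:
                items[new_key] = value
        return items

    print(flatten_dict({"span": {"start": 0, "end": 4}, "text": "demo"}))
    # {'span_start': 0, 'span_end': 4, 'text': 'demo'}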
 class AddFields(StreamInstanceOperator):
-    """
-    Adds specified fields to each instance in a stream.
+    """Adds specified fields to each instance in a given stream or all streams (default). If fields exist, updates them.

     Args:
         fields (Dict[str, object]): The fields to add to each instance.
+        use_query (bool): Use '/' to access inner fields.
+        use_deepcopy (bool): Deep copy the input value to avoid later modifications.
+
+    Examples:
+        # Add a 'classes' field with a value of a list "positive" and "negative" to all streams
+        AddFields(fields={"classes": ["positive", "negatives"]})
+
+        # Add a 'start' field under the 'span' field with a value of 0 to all streams
+        AddFields(fields={"span/start": 0})
+
+        # Add a 'classes' field with a value of a list "positive" and "negative" to the 'train' stream
+        AddFields(fields={"classes": ["positive", "negatives"]}, apply_to_stream=["train"])
+
+        # Add a 'classes' field from a given list, preventing modification of the original list
+        # from changing the instance.
+        AddFields(fields={"classes": alist}, use_deepcopy=True)
     """

     fields: Dict[str, object]
     use_query: bool = False
     use_deepcopy: bool = False

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         if self.use_query:
             for key, value in self.fields.items():
                 if self.use_deepcopy:
@@ -138,30 +204,31 @@ class AddFields(StreamInstanceOperator):


 class RemoveFields(StreamInstanceOperator):
-    """
-    Adds specified fields to each instance in a stream.
+    """Removes specified fields from each instance in a stream.

     Args:
-        fields (
+        fields (List[str]): The fields to remove from each instance.
     """

     fields: List[str]

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        for field_name in self.fields:
+            del instance[field_name]
         return instance


 class FieldOperator(StreamInstanceOperator):
-    """
+    """A general stream operator that processes the values of a field (or multiple ones).

     Args:
         field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
         to_field (Optional[str]): Field name to save the result into, if only one field is to be saved. If None is passed, the operator works in-place and replaces "field". Defaults to None
         field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
         process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
+        use_query (bool): Whether to use dpath style queries. Defaults to False.
     """

     field: Optional[str] = None
@@ -175,14 +242,18 @@ class FieldOperator(StreamInstanceOperator):
     def verify(self):
         super().verify()

-        assert
+        assert (
+            self.field is not None or self.field_to_field is not None
+        ), "Must supply a field to work on"
         assert (
             self.to_field is None or self.field_to_field is None
         ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
         assert (
             self.field is None or self.field_to_field is None
         ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
-        assert
+        assert (
+            self._field_to_field
+        ), f"the from and to fields must be defined, got: {self._field_to_field}"

     @abstractmethod
     def process_value(self, value: Any) -> Any:
@@ -195,11 +266,13 @@ class FieldOperator(StreamInstanceOperator):
             self._field_to_field = [(self.field, self.to_field)]
         else:
             try:
-                self._field_to_field =
+                self._field_to_field = list(self.field_to_field.items())
             except AttributeError:
                 self._field_to_field = self.field_to_field

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         for from_field, to_field in self._field_to_field:
             try:
                 old_value = dict_get(
@@ -209,27 +282,40 @@ class FieldOperator(StreamInstanceOperator):
                     default=self.get_default,
                     not_exist_ok=self.not_exist_ok,
                 )
-            except
-                raise
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to get '{from_field}' from {instance} due to: {e}"
+                ) from e
+            try:
+                if self.process_every_value:
+                    new_value = [self.process_value(value) for value in old_value]
+                else:
+                    new_value = self.process_value(old_value)
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to process '{from_field}' from {instance} due to: {e}"
+                ) from e
             if self.use_query and is_subpath(from_field, to_field):
                 dict_delete(instance, from_field)
-            dict_set(
+            dict_set(
+                instance,
+                to_field,
+                new_value,
+                use_dpath=self.use_query,
+                not_exist_ok=True,
+            )
         return instance


 class RenameFields(FieldOperator):
-    """
-    Renames fields
-    """
+    """Renames fields."""

     def process_value(self, value: Any) -> Any:
         return value

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         res = super().process(instance=instance, stream_name=stream_name)
         vals = [x[1] for x in self._field_to_field]
         for key, _ in self._field_to_field:
@@ -241,32 +327,202 @@ class RenameFields(FieldOperator):


 class AddConstant(FieldOperator):
-    """
+    """Adds a value, similar to add + field.

     Args:
-
+        add: sum to add.
     """
+
+    add: Any
+
+    def process_value(self, value: Any) -> Any:
+        return self.add + value
+
+
+class Augmentor(StreamInstanceOperator):
+    """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
+
+    Args:
+        augment_model_input: Whether to augment the input to the model.
+        augment_task_input: Whether to augment the task input fields. The specific fields are defined in the FormTask operator.
+
+    """
+
+    augment_task_input: bool = False
+    augment_model_input: bool = False
+
+    def verify(self):
+        assert not (
+            self.augment_task_input and self.augment_model_input
+        ), "Augmentor must set either 'augment_task_input' or 'augment_model_input' but not both"
+        assert (
+            self.augment_task_input or self.augment_model_input
+        ), "Augmentor must set either 'augment_task_input' or 'augment_model_input'"
+
+        super().verify()
+
+    @abstractmethod
+    def process_value(self, value: Any) -> Any:
+        pass
+
+    def prepare(self):
+        pass
+
+    def set_task_input_fields(self, task_input_fields: List[str]):
+        self._task_input_fields = [
+            "inputs/" + task_input_field for task_input_field in task_input_fields
+        ]
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if self.augment_task_input:
+            assert (
+                len(self._task_input_fields) > 0
+            ), "No augmentable input fields were defined in FormTask, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the FormTask."
+            fields = self._task_input_fields
+            assert not self.augment_model_input
+
+        if self.augment_model_input:
+            fields = ["source"]
+            assert not self.augment_task_input
+
+        for field_name in fields:
+            try:
+                old_value = dict_get(
+                    instance,
+                    field_name,
+                    use_dpath=True,
+                    default="",
+                    not_exist_ok=False,
+                )
+            except TypeError as e:
+                raise TypeError(f"Failed to get {field_name} from {instance}") from e
+
+            # We are setting a nested seed based on the value processed, to ensure that
+            # the augmentation randomizations do not affect other randomization choices and
+            # to make the augmentation randomization choices different for each text.
+            with nested_seed(str(hash(old_value))):
+                try:
+                    new_value = self.process_value(old_value)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
+                    ) from e
+            dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
+        return instance

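The nested_seed(str(hash(old_value))) call above is the key design choice: each value seeds its own randomization, so augmentations are deterministic per text and independent of other random draws. A standalone sketch of the idea using random.Random directly (note that str hashes are salted per process, so determinism holds within a run):

    import random

    def augment_deterministically(text, augment_fn):
        # A private RNG seeded from the value itself: the same text always gets
        # the same augmentation within a process, and global random state is
        # left untouched.
        rng = random.Random(hash(text))
        return augment_fn(text, rng)

    def stretch_spaces(text, rng):
        return text.replace(" ", " " * rng.randint(1, 3))

    print(augment_deterministically("a b c", stretch_spaces))
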
+class NullAugmentor(Augmentor):
+    def verify(self):
+        pass
+
+    def process_value(self, value: Any) -> Any:
+        return value
+
+
+class AugmentWhitespace(Augmentor):
+    """Augments the inputs by replacing existing whitespace with other whitespace.
+
+    Currently each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
+    """
+
+    def process_value(self, value: Any) -> Any:
+        import re
+
+        words = re.split(r"(\s+)", value)
+        new_value = ""
+
+        for word in words:
+            if word.isspace():
+                new_value += get_random().choice(
+                    ["\n", "\t", " "]
+                ) * get_random().randint(1, 3)
+            else:
+                new_value += word
+        return new_value

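A self-contained version of the whitespace replacement above, using the standard random module in place of the library's get_random():

    import random
    import re

    def augment_whitespace(value, rng=random):
        # Split on whitespace runs, keeping the separators (capturing group).
        parts = re.split(r"(\s+)", value)
        out = ""
        for part in parts:
            if part.isspace():
                # Replace each whitespace run with 1-3 random whitespace characters.
                out += rng.choice(["\n", "\t", " "]) * rng.randint(1, 3)
            else:
                out += part
        return out

    print(repr(augment_whitespace("The quick brown fox")))
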
+class AugmentSuffix(Augmentor):
+    r"""Augments the input by appending to it a randomly selected (typically, whitespace) pattern.
+
+    Args:
+        suffixes: the potential (typically, whitespace) patterns to select from.
+            The dictionary version allows specifying relative weights for the different patterns.
+        remove_existing_trailing_whitespaces: allows first cleaning existing trailing whitespaces.
+            The selected pattern is then appended to the (potentially end-trimmed) input.
+
+    Examples:
+        To append a '\n' or a '\t' to the end of the input, employ
+        AugmentSuffix(augment_model_input=True, suffixes=['\n','\t'])
+        If '\n' is preferred over '\t', at a 2:1 ratio, employ
+        AugmentSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1})
+        which will append '\n' twice as often as '\t'.
+
+    """
+
+    suffixes: Optional[Union[List[str], Dict[str, int]]] = [" ", "\n", "\t"]
+    remove_existing_trailing_whitespaces: Optional[bool] = False
+
+    def verify(self):
+        assert (
+            isinstance(self.suffixes, list) or isinstance(self.suffixes, dict)
+        ), f"Argument 'suffixes' should be either a list or a dictionary, whereas it is of type {type(self.suffixes)}"
+
+        if isinstance(self.suffixes, dict):
+            for k, v in self.suffixes.items():
+                assert isinstance(
+                    k, str
+                ), f"suffixes should map strings, whereas key {k!s} is of type {type(k)}"
+                assert isinstance(
+                    v, int
+                ), f"suffixes should map to ints, whereas value {v!s} is of type {type(v)}"
+        else:
+            for k in self.suffixes:
+                assert isinstance(
+                    k, str
+                ), f"suffixes should be a list of strings, whereas member {k!s} is of type {type(k)}"
+
+        self.pats = (
+            self.suffixes
+            if isinstance(self.suffixes, list)
+            else [k for k, v in self.suffixes.items()]
+        )
+        total_weight = (
+            len(self.pats)
+            if isinstance(self.suffixes, list)
+            else sum([v for k, v in self.suffixes.items()])
+        )
+        self.weights = (
+            [1.0 / total_weight] * len(self.pats)
+            if isinstance(self.suffixes, list)
+            else [float(self.suffixes[p]) / total_weight for p in self.pats]
+        )
+        super().verify()
+
+    def process_value(self, value: Any) -> Any:
+        assert value is not None, "input value should not be None"
+        new_value = str(value)
+        if self.remove_existing_trailing_whitespaces:
+            new_value = new_value.rstrip()
+        new_value += get_random().choices(self.pats, self.weights, k=1)[0]
+
+        return new_value

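For the dictionary form, the weight arithmetic in verify() amounts to passing relative weights to random.choices; a compact sketch:

    import random

    def pick_suffix(suffixes):
        # suffixes: either a list (uniform choice) or a dict of pattern -> relative weight.
        if isinstance(suffixes, dict):
            pats, weights = zip(*suffixes.items())
        else:
            pats, weights = suffixes, None  # None means uniform in random.choices
        return random.choices(pats, weights=weights, k=1)[0]

    picks = [pick_suffix({"\n": 2, "\t": 1}) for _ in range(9000)]
    print(picks.count("\n") / picks.count("\t"))  # roughly 2.0
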
+class ShuffleFieldValues(FieldOperator):
+    """Shuffles an iterable value."""
+
     def process_value(self, value: Any) -> Any:
         res = list(value)
+        get_random().shuffle(res)
         return res


 class JoinStr(FieldOperator):
-    """
+    """Joins a list of strings (contents of a field), similar to str.join().

     Args:
         separator (str): text to put between values
     """
@@ -278,6 +534,25 @@ class JoinStr(FieldOperator):


 class Apply(StreamInstanceOperator):
+    """A class used to apply a python function and store the result in a field.
+
+    Args:
+        function (str): name of function.
+        to_field (str): the field to store the result.
+        Additional arguments are field names whose values are passed to the function.
+
+    Examples:
+        Store in field "b" the uppercase string of the value in field "a":
+        Apply("a", function=str.upper, to_field="b")
+
+        Dump the json representation of field "t" and store back in the same field:
+        Apply("t", function=json.dumps, to_field="t")
+
+        Set the time in a field 'b':
+        Apply(function=time.time, to_field="b")
+
+    """
+
     __allow_unexpected_arguments__ = True
     function: Callable = NonPositionalField(required=True)
     to_field: str = NonPositionalField(required=True)
@@ -292,25 +567,23 @@ class Apply(StreamInstanceOperator):
         else:
             parts.append(function.__name__)

-        return result
+        return ".".join(parts)

     def str_to_function(self, function_str: str) -> Callable:
         splitted = function_str.split(".", 1)
         if len(splitted) == 1:
-            return __builtins__[
-            obj = globals()[module_name]
-        else:
-            obj = importlib.import_module(module_name)
-        for part in function_name.split("."):
-            obj = getattr(obj, part)
-        return obj
+            return __builtins__[splitted[0]]
+
+        module_name, function_name = splitted
+        if module_name in __builtins__:
+            obj = __builtins__[module_name]
+        elif module_name in globals():
+            obj = globals()[module_name]
+        else:
+            obj = importlib.import_module(module_name)
+        for part in function_name.split("."):
+            obj = getattr(obj, part)
+        return obj

     def prepare(self):
         super().prepare()
@@ -318,7 +591,9 @@ class Apply(StreamInstanceOperator):
         self.function = self.str_to_function(self.function)
         self._init_dict["function"] = self.function_to_str(self.function)

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         argv = [instance[arg] for arg in self._argv]
         kwargs = {key: instance[val] for key, val in self._kwargs}
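
The dotted-name resolution in str_to_function can be checked in isolation; a simplified sketch that skips the globals() fallback the method also consults:

    import builtins
    import importlib

    def str_to_function(function_str):
        # Resolve "len" from builtins, or "json.dumps" by importing the module
        # and walking attributes.
        module_name, _, attr_path = function_str.partition(".")
        if not attr_path:
            return getattr(builtins, module_name)
        obj = importlib.import_module(module_name)
        for part in attr_path.split("."):
            obj = getattr(obj, part)
        return obj

    assert str_to_function("len")([1, 2]) == 2
    assert str_to_function("json.dumps")({"a": 1}) == '{"a": 1}'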
@@ -329,36 +604,36 @@ class Apply(StreamInstanceOperator):


 class ListFieldValues(StreamInstanceOperator):
-    """
-    Concatanates values of multiple fields into a list to list(fields)
-    """
+    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

-    fields: str
+    fields: List[str]
     to_field: str
     use_query: bool = False

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         values = []
-        for
-            values.append(dict_get(instance,
+        for field_name in self.fields:
+            values.append(dict_get(instance, field_name, use_dpath=self.use_query))
         instance[self.to_field] = values
         return instance


 class ZipFieldValues(StreamInstanceOperator):
-    """
-    Zips values of multiple fields similar to list(zip(*fields))
-    """
+    """Zips values of multiple fields similar to list(zip(*fields))."""

     fields: str
     to_field: str
     longest: bool = False
     use_query: bool = False

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         values = []
-        for
-            values.append(dict_get(instance,
+        for field_name in self.fields:
+            values.append(dict_get(instance, field_name, use_dpath=self.use_query))
         if self.longest:
             zipped = zip_longest(*values)
         else:
@@ -368,16 +643,16 @@ class ZipFieldValues(StreamInstanceOperator):


 class IndexOf(StreamInstanceOperator):
-    """
-    Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)
-    """
+    """Finds the location of one value in another (iterable) value, similar to to_field=search_in.index(index_of)."""

     search_in: str
     index_of: str
     to_field: str
     use_query: bool = False

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
         item = dict_get(instance, self.index_of, use_dpath=self.use_query)
         instance[self.to_field] = lst.index(item)
@@ -385,9 +660,7 @@ class IndexOf(StreamInstanceOperator):


 class TakeByField(StreamInstanceOperator):
-    """
-    Takes value from one field based on another field similar to field[index]
-    """
+    """Takes a value from one field based on another field, similar to field[index]."""

     field: str
     index: str
@@ -398,7 +671,9 @@ class TakeByField(StreamInstanceOperator):
         if self.to_field is None:
             self.to_field = self.field

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         value = dict_get(instance, self.field, use_dpath=self.use_query)
         index_value = dict_get(instance, self.index, use_dpath=self.use_query)
         instance[self.to_field] = value[index_value]
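
ZipFieldValues, IndexOf, and TakeByField compose like ordinary dict transforms; an illustrative chain over a single instance (plain-Python equivalents with invented field names, not the operator classes):

    instance = {"choices": ["yes", "no"], "answer": "no"}

    # ZipFieldValues-like: zip two list fields together
    instance["pairs"] = list(zip(instance["choices"], ["A", "B"]))

    # IndexOf-like: to_field = search_in.index(index_of)
    instance["answer_idx"] = instance["choices"].index(instance["answer"])

    # TakeByField-like: field[index]
    instance["picked"] = instance["choices"][instance["answer_idx"]]

    print(instance["answer_idx"], instance["picked"])  # 1 no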
@@ -406,8 +681,7 @@ class TakeByField(StreamInstanceOperator):


 class CopyFields(FieldOperator):
-    """
-    Copies specified fields from one field to another.
+    """Copies specified fields from one field to another.

     Args:
         field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
@@ -421,14 +695,15 @@ class CopyFields(FieldOperator):
 class AddID(StreamInstanceOperator):
     id_field_name: str = "id"

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
         return instance


 class CastFields(StreamInstanceOperator):
-    """
-    Casts specified fields to specified types.
+    """Casts specified fields to specified types.

     Args:
         types (Dict[str, str]): A dictionary mapping fields to their new types.
@@ -451,24 +726,28 @@ class CastFields(StreamInstanceOperator):
     def _cast_single(self, value, type, field):
         try:
             return self.types[type](value)
-        except:
+        except Exception as e:
             if field not in self.failure_defaults:
                 raise ValueError(
                     f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
-                )
+                ) from e
             return self.failure_defaults[field]

     def _cast_multiple(self, values, type, field):
         values = [self._cast_single(value, type, field) for value in values]

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        for field_name, type in self.fields.items():
+            value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
             if self.cast_multiple:
-                casted_value = self._cast_multiple(value, type,
+                casted_value = self._cast_multiple(value, type, field_name)
             else:
-                casted_value = self._cast_single(value, type,
+                casted_value = self._cast_single(value, type, field_name)
+            dict_set(
+                instance, field_name, casted_value, use_dpath=self.use_nested_query
+            )
         return instance

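The cast-with-fallback behavior of _cast_single can be sketched standalone (assumed types table mirroring the class's, illustrative field names):

    TYPES = {"int": int, "float": float, "str": str, "bool": bool}

    def cast_single(value, type_name, field, failure_defaults):
        try:
            return TYPES[type_name](value)
        except (ValueError, TypeError) as e:
            if field not in failure_defaults:
                raise ValueError(f"Failed to cast {field!r}={value!r} to {type_name}") from e
            return failure_defaults[field]

    print(cast_single("3", "int", "score", {}))              # 3
    print(cast_single("n/a", "int", "score", {"score": 0}))  # falls back to 0
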
@@ -491,13 +770,14 @@ class DivideAllFieldsBy(StreamInstanceOperator):
     strict: bool = False
     recursive: bool = True

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         return recursive_divide(instance, self.divisor, strict=self.strict)


 class ArtifactFetcherMixin:
-    """
-    Provides a way to fetch and cache artifacts in the system.
+    """Provides a way to fetch and cache artifacts in the system.

     Args:
         cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
@@ -514,8 +794,7 @@ class ArtifactFetcherMixin:


 class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
-    """
-    Applies value operators to each instance in a stream based on specified fields.
+    """Applies value operators to each instance in a stream based on specified fields.

     Args:
         value_field (str): The field containing the value to be operated on.
@@ -529,7 +808,9 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
     default_operators: List[str] = None
     fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)

-    def process(
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         operator_names = instance.get(self.operators_field)
         if operator_names is None:
             assert (
@@ -542,35 +823,228 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):

         for name in operator_names:
             operator = self.get_artifact(name)
-            for
-                value = instance[
-                if
-                    instance[
+            for field_name in self.inputs_fields:
+                value = instance[field_name]
+                if field_name in self.fields_to_treat_as_list:
+                    instance[field_name] = [operator.process(v) for v in value]
                 else:
-                    instance[
+                    instance[field_name] = operator.process(instance[field_name])

         return instance


 class FilterByValues(SingleStreamOperator):
-    """
+    """Filters a stream, yielding only instances that match specified values in the provided fields.

     Args:
-
+        required_values (Dict[str, Any]): For each field, the value that instances should match to be included in the output.
     """

-
+    required_values: Dict[str, Any]

-    def process(self, stream: Stream, stream_name: str = None) -> Generator:
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         for instance in stream:
-
+            filter = False
+            for key, value in self.required_values.items():
+                if key not in instance:
+                    raise ValueError(
+                        f"Required filter field ('{key}') in FilterByValues is not found in {instance}"
+                    )
+                if instance[key] != value:
+                    filter = True
+            if not filter:
                 yield instance

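The same filtering logic over a plain list of instances, for reference (a sketch, not the operator class):

    def filter_by_values(stream, required_values):
        # Yield only instances whose fields equal all required values.
        for instance in stream:
            for key, value in required_values.items():
                if key not in instance:
                    raise ValueError(f"Required filter field {key!r} not found in {instance}")
                if instance[key] != value:
                    break
            else:
                yield instance

    data = [{"split": "train", "x": 1}, {"split": "test", "x": 2}]
    print(list(filter_by_values(data, {"split": "train"})))  # [{'split': 'train', 'x': 1}]
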
+class ExtractFieldValues(MultiStreamOperator):
+    """Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them
+    as a list in a new field ('to_field') in all streams.
+
+    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
+    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
+    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
+    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
+    less than 'min_frequency_percent' of the total number of instances in the stream.
+    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value.
+
+    Examples:
+
+    ExtractFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
+    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
+    every instance in all streams.
+
+    ExtractFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
+    in case field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
+    value members in these lists, and report the most frequent values.
+    If process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
+    'to_field' of each instance of all streams.
+
+    ExtractFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
+    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
+    and stores them in field 'classes' of each instance of all streams.
+
+    ExtractFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
+    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
+    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
+    """
+
+    field: str
+    stream_name: str
+    overall_top_frequency_percent: Optional[int] = 100
+    min_frequency_percent: Optional[int] = 0
+    to_field: str
+    process_every_value: Optional[bool] = False
+
+    def verify(self):
+        assert (
+            self.overall_top_frequency_percent <= 100
+            and self.overall_top_frequency_percent >= 0
+        ), "'overall_top_frequency_percent' must be between 0 and 100"
+        assert (
+            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
+        ), "'min_frequency_percent' must be between 0 and 100"
+        assert not (
+            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
+        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from their default value"
+        super().verify()
+
+    def process(self, multi_stream: MultiStream) -> MultiStream:
+        stream = multi_stream[self.stream_name]
+        all_values = []
+        for instance in stream:
+            if (not isinstance(instance[self.field], list)) and (
+                self.process_every_value is True
+            ):
+                raise ValueError(
+                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
+                )
+            if (not isinstance(instance[self.field], list)) or (
+                self.process_every_value is False
+            ):
+                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
+                all_values.append(
+                    (*instance[self.field],)
+                    if isinstance(instance[self.field], list)
+                    else instance[self.field]
+                )  # convert to a tuple if list, to enable the use of Counter, which would not accept
+                # a list as an entity to count its occurrences
+            else:
+                # the content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
+                all_values.extend(instance[self.field])
+        counter = Counter(
+            all_values
+        )  # here all_values is a list of individual values, or tuples. Hence, Counter is feasible
+        values_and_counts = counter.most_common()
+        if self.overall_top_frequency_percent < 100:
+            top_frequency = len(all_values) * self.overall_top_frequency_percent / 100.0
+            sum_counts = 0
+            for _i, p in enumerate(values_and_counts):
+                sum_counts += p[1]
+                if sum_counts >= top_frequency:
+                    break
+            values_and_counts = counter.most_common(_i + 1)
+        if self.min_frequency_percent > 0:
+            min_frequency = self.min_frequency_percent * len(all_values) / 100.0
+            while values_and_counts[-1][1] < min_frequency:
+                values_and_counts.pop()
+        values_to_keep = [
+            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
+            for ele in values_and_counts
+        ]
+        for name in multi_stream:
+            for instance in multi_stream[name]:
+                instance[self.to_field] = values_to_keep
+        return multi_stream

|
| 961 |
+


class FilterByListsOfValues(SingleStreamOperator):
    """Filters a stream, yielding only instances whose field values are included in the specified value lists.

    Args:
        required_values (Dict[str, List]): For each field, the list of values that instances should match to be included in the output.
    """

    required_values: Dict[str, List]

    def verify(self):
        super().verify()
        for key, value in self.required_values.items():
            if not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByListsOfValues is not a list but '{value}'"
                )

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            filtered_out = False
            for key, value in self.required_values.items():
                if key not in instance:
                    raise ValueError(
                        f"Required filter field ('{key}') in FilterByListsOfValues is not found in {instance}"
                    )
                if instance[key] not in value:
                    filtered_out = True
            if not filtered_out:
                yield instance
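

# Illustrative usage sketch (hypothetical data): unlike FilterByValues, which
# matches one exact value per field, FilterByListsOfValues accepts any of the
# listed values per field:
#
#     op = FilterByListsOfValues(required_values={"label": ["positive", "neutral"]})
#     ms = op(MultiStream({"train": train_stream}))  # keeps instances labeled "positive" or "neutral"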


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list): The list of values to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]
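

# Illustrative sketch of the per-value behavior (process_value is the hook
# defined above; how the enclosing FieldOperator selects the field to read and
# write is not shown in this diff):
#
#     Intersect(allowed_values=["a", "b"]).process_value(["a", "c", "b", "d"])
#     # -> ["a", "b"]  (the original list order is preserved)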


class RemoveValues(FieldOperator):
    """Removes specified elements from a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list): The values to remove from the field.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in RemoveValues operator"
            )

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]
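

# Illustrative sketch (hypothetical values): the mirror image of Intersect,
# dropping the unallowed values instead of keeping the allowed ones:
#
#     RemoveValues(unallowed_values=["b"]).process_value(["a", "b", "c"])
#     # -> ["a", "c"]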


class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    # ... (lines collapsed in the diff view) ...

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    # ... (lines collapsed in the diff view) ...


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    # ... (lines collapsed in the diff view) ...

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # ... (the opening lines of this method, which build 'result' and 'uniques',
        # are collapsed in the diff view) ...
        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByValues(
                    required_values=filtering_values
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)
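

# Illustrative usage sketch (hypothetical stream and field names): splitting a
# "test" stream by the values of field "template" yields one stream per
# observed value combination, named "<origin stream>_<value>", e.g.
# "test_templateA":
#
#     ms = SplitByValue(fields=["template"])(MultiStream({"test": test_stream}))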


class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
    """

    # ... (lines collapsed in the diff view) ...

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peak()

        operators = first_instance.get(self.field, [])

        # ... (lines collapsed in the diff view) ...

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({"tmp": stream}))["tmp"]

        yield from stream
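

# Illustrative usage sketch (the field name "postprocessors" is hypothetical):
# each instance is expected to list operator artifact names under `field`, and
# the chain read from the first instance is applied to the whole stream:
#
#     op = ApplyStreamOperatorsField(field="postprocessors")
#     ms = op(MultiStream({"test": test_stream}))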


class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval

        first_instance = stream.peak()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' of instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metric_names = list(reversed(metric_names))

        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            assert isinstance(
                metric, Metric
            ), f"Operator {metric_name} must be a Metric"

            if not self.calc_confidence_intervals:
                if isinstance(metric, MetricWithConfidenceInterval):
                    metric.disable_confidence_interval_calculation()
                elif isinstance(metric, MetricPipeline) and isinstance(
                    metric.metric, MetricWithConfidenceInterval
                ):
                    metric.metric.disable_confidence_interval_calculation()

            stream = metric(MultiStream({"tmp": stream}))["tmp"]

        yield from stream
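

# Illustrative usage sketch (the artifact name "metrics.accuracy" is
# hypothetical): instances carry their metric artifact names in `metric_field`;
# since the metrics are applied in reverse order, the first listed metric is
# the one that determines the final main score:
#
#     op = ApplyMetric(metric_field="metrics", calc_confidence_intervals=False)
#     ms = op(MultiStream({"test": scored_stream}))  # e.g. instances hold {"metrics": ["metrics.accuracy"], ...}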


class AddFieldNamePrefix(StreamInstanceOperator):
    """Adds a prefix to each field name in each instance of a stream.

    Args:
        prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
    """

    prefix_dict: Dict[str, str]

    def prepare(self):
        return super().prepare()

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return {
            self.prefix_dict[stream_name] + key: value
            for key, value in instance.items()
        }
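

# Illustrative usage sketch (hypothetical prefixes): with prefix_dict mapping
# stream names to prefixes, an instance {"text": ...} arriving on stream
# "train" is rewritten as {"train_text": ...}:
#
#     op = AddFieldNamePrefix(prefix_dict={"train": "train_", "test": "test_"})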


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream):
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: Stream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )
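

# Illustrative usage sketch: with the defaults above, streams "train" and
# "test" are merged into a single "all" stream, and each instance is stamped
# with its origin under the "origin" field:
#
#     merged = MergeStreams()(ms)            # ms holds streams "train" and "test"
#     for instance in merged["all"]:
#         origin = instance["origin"]        # "train" or "test"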


class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args:
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        get_random().shuffle(page)
        yield from page
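

# Illustrative note: shuffling is page-wise, so with the default page_size of
# 1000 an instance can only move within its own page of 1000 instances:
#
#     shuffled = Shuffle(page_size=1000)(ms)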


class EncodeLabels(StreamInstanceOperator):
    """Encodes the labels of the specified fields together into integers.

    Args:
        fields (List[str]): The fields to encode together.
    """

    fields: List[str]

    # ... (lines collapsed in the diff view) ...

    def _process_multi_stream(self, multi_stream):
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name, use_dpath=True)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            dict_set(
                instance, field_name, new_values, use_dpath=True, set_multiple=True
            )

        return instance
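

# Illustrative sketch (hypothetical labels): fields encoded together share one
# encoder, so the same label maps to the same integer across all listed fields;
# e.g. labels "cat", "dog", "cat" become 0, 1, 0:
#
#     encoded = EncodeLabels(fields=["label"])(ms)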


class StreamRefiner(SingleStreamOperator):
    max_instances: int = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream
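

# Illustrative usage sketch: a StreamRefiner caps each stream at max_instances
# (here, its first 100 instances) and passes the stream through unchanged when
# max_instances is None:
#
#     capped = StreamRefiner(max_instances=100)(ms)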


class DeterministicBalancer(StreamRefiner):
    """A class used to balance streams deterministically.

    Attributes:
        fields (List[str]): A list of field names to be used in determining the signature of an instance.
    """

    # ... (lines collapsed in the diff view) ...

    fields: List[str]

    def signature(self, instance):
        return str(
            tuple(dict_get(instance, field, use_dpath=True) for field in self.fields)
        )

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = collections.Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = collections.Counter()

        # ... (the remainder of this method is collapsed in the diff view) ...


# ... (the diff view collapses the definition of the next class, to which the
# following length-based 'signature' method belongs) ...

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name, use_dpath=True))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i