File size: 3,711 Bytes
a8f310f 4868000 f8c1b22 4868000 92d8b2d 4868000 a8f310f 4868000 a8f310f 0f18707 a8f310f 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c 4868000 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c a8f310f 8d5bd0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from typing import Any, Dict, List, Optional
from .operator import StreamInstanceOperator
class Tasker:
    """Marker base class for task operators; it defines no behavior of its own."""
class FormTask(Tasker, StreamInstanceOperator):
    """FormTask packs the different instance fields into dictionaries by their roles in the task.

    The output instance contains three fields:
        "inputs" -- a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
        "outputs" -- a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'outputs'.
        "metrics" -- the value of Arg 'metrics', passed through unchanged.

    Args:
        inputs: names of the instance fields to collect under "inputs".
        outputs: names of the instance fields to collect under "outputs".
        metrics: value stored as-is under "metrics" in every processed instance.
        augmentable_inputs: field names that must all be listed in 'inputs'; this is enforced by verify().
    """

    inputs: List[str]
    outputs: List[str]
    metrics: List[str]
    # NOTE(review): class-level mutable default; presumably the operator base class
    # copies defaults per instance -- confirm before mutating this list in place.
    augmentable_inputs: List[str] = []

    def verify(self):
        """Check that every augmentable input is one of the declared inputs.

        Raises:
            AssertionError: if a name in 'augmentable_inputs' is not listed in 'inputs'.
        """
        for augmentable_input in self.augmentable_inputs:
            if augmentable_input not in self.inputs:
                # Raised explicitly instead of through an `assert` statement so the check
                # is not stripped under `python -O`; the exception type is kept so that
                # existing callers catching AssertionError keep working.
                raise AssertionError(
                    f"augmentable_input {augmentable_input} is not part of {self.inputs}"
                )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Split the instance into its "inputs", "outputs" and "metrics" parts.

        Args:
            instance: the instance to process; must contain every field named in 'inputs' and 'outputs'.
            stream_name: accepted for interface compatibility; not used here.

        Returns:
            A new dictionary with the keys "inputs", "outputs" and "metrics".

        Raises:
            KeyError: if a field named in 'inputs' or 'outputs' is missing from the instance.
        """
        try:
            inputs = {key: instance[key] for key in self.inputs}
        except KeyError as e:
            missing_inputs = [key for key in self.inputs if key not in instance]
            raise KeyError(
                f"Unexpected FormTask input column names ({missing_inputs})."
                f" The available input names: {list(instance.keys())}"
            ) from e

        try:
            outputs = {key: instance[key] for key in self.outputs}
        except KeyError as e:
            raise KeyError(
                f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}"
                f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
            ) from e

        return {
            "inputs": inputs,
            "outputs": outputs,
            "metrics": self.metrics,
        }
class MultipleChoiceTask(FormTask):
    """FormTask variant that enumerates the answer choices with labels taken from 'alphabet'.

    After the regular FormTask processing:
        * the input field named by 'choices_field' (a list of choices) is replaced by one string in
          which every choice is rendered as "<label><enumeration_suffix><choice>" and the rendered
          choices are joined by 'choices_separator';
        * the first output field, whose value must be one of the choices, is replaced by the label of
          that choice, followed by the suffix and the choice text when 'use_text_in_target' is true.
    """

    choices_field: str = "choices"
    choices_separator: str = "\n"
    enumeration_suffix: str = ". "
    use_text_in_target: bool = False
    alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def process_single_choice(
        self, choice: str, index: int, use_text: bool = True
    ) -> str:
        """Render one choice as its label, optionally followed by the suffix and the choice text.

        Args:
            choice: the text of the choice.
            index: position of the choice; selects the label from 'alphabet'.
            use_text: when true, append 'enumeration_suffix' and the choice text to the label.

        Raises:
            ValueError: if 'index' is beyond the end of 'alphabet'.
        """
        try:
            label = self.alphabet[index]
        except IndexError as e:
            raise ValueError(
                f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
            ) from e
        if use_text:
            return f"{label}{self.enumeration_suffix}{choice}"
        return label

    def process_choices(self, choices: List[str]) -> str:
        """Render all choices, each with its label, joined by 'choices_separator'."""
        return self.choices_separator.join(
            [
                self.process_single_choice(choice, index)
                for index, choice in enumerate(choices)
            ]
        )

    def process_target(self, choices, target_index):
        """Render the choice at 'target_index' as the target, honoring 'use_text_in_target'."""
        return self.process_single_choice(
            choices[target_index], target_index, use_text=self.use_text_in_target
        )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Form the task and replace the choices and the target by their enumerated renderings.

        Args:
            instance: the instance to process; see FormTask.process for the required fields.
            stream_name: forwarded to FormTask.process.

        Returns:
            The FormTask result with the choices input and the first output field rewritten.

        Raises:
            ValueError: if there is no output field, if the target value is not one of the
                choices, or if there are more choices than labels in 'alphabet'.
            KeyError: if a required field is missing (raised by FormTask.process), or if
                'choices_field' is not among the collected inputs.
        """
        result = super().process(instance, stream_name)

        outputs = result["outputs"]
        if not outputs:
            # Without this guard, next() would raise a bare StopIteration, which silently
            # ends (or surfaces as an opaque RuntimeError in) generator-driven pipelines.
            raise ValueError(
                "MultipleChoiceTask requires at least one output field holding the target, "
                "but no output fields were given."
            )
        # The first declared output field is treated as the target.
        target_key, target_value = next(iter(outputs.items()))

        choices = result["inputs"][self.choices_field]
        try:
            target_index_in_choices = choices.index(target_value)
        except ValueError as e:
            raise ValueError(
                f"The target {target_value!r} of output field '{target_key}' "
                f"is not one of the choices: {choices}"
            ) from e

        # Compute both renderings before overwriting anything, so a failure in either
        # leaves 'result' untouched.
        processed_choices = self.process_choices(choices)
        processed_target = self.process_target(choices, target_index_in_choices)

        result["inputs"][self.choices_field] = processed_choices
        result["outputs"][target_key] = processed_target
        return result
|