File size: 3,711 Bytes
a8f310f
4868000
f8c1b22
4868000
 
 
 
 
 
 
92d8b2d
 
 
 
 
 
 
 
 
4868000
 
 
a8f310f
4868000
a8f310f
 
 
 
0f18707
a8f310f
 
 
 
8d5bd0c
 
 
 
a8f310f
 
 
8d5bd0c
 
 
 
a8f310f
8d5bd0c
a8f310f
8d5bd0c
4868000
 
 
 
 
8d5bd0c
 
 
 
 
 
 
 
 
a8f310f
 
 
8d5bd0c
 
a8f310f
8d5bd0c
 
a8f310f
8d5bd0c
 
 
 
 
 
 
 
 
 
 
a8f310f
 
 
8d5bd0c
a8f310f
 
 
8d5bd0c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import Any, Dict, List, Optional

from .operator import StreamInstanceOperator


class Tasker:
    """Marker base class with no attributes or behavior of its own."""


class FormTask(Tasker, StreamInstanceOperator):
    """FormTask packs the different instance fields into dictionaries by their roles in the task.

    The output instance contains three fields:
        "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
        "outputs" -- for the fields listed in Arg "outputs".
        "metrics" -- to contain the value of Arg 'metrics'

    Args:
        inputs: names of the instance fields copied into the "inputs" sub-dictionary.
        outputs: names of the instance fields copied into the "outputs" sub-dictionary.
        metrics: value placed as-is under the "metrics" key of every processed instance.
        augmentable_inputs: names that must each also appear in 'inputs' (checked by verify()).
    """

    inputs: List[str]
    outputs: List[str]
    metrics: List[str]
    augmentable_inputs: List[str] = []

    def verify(self):
        """Check that every augmentable input is also declared as an input.

        Raises:
            AssertionError: if a name in 'augmentable_inputs' is not in 'inputs'.
        """
        for augmentable_input in self.augmentable_inputs:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; the exception type is kept unchanged.
            if augmentable_input not in self.inputs:
                raise AssertionError(
                    f"augmentable_input {augmentable_input} is not part of {self.inputs}"
                )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Split an instance into "inputs", "outputs" and "metrics".

        Args:
            instance: the flat instance; must contain every field named in
                'inputs' and 'outputs'.
            stream_name: accepted for interface compatibility; not used here.

        Returns:
            A new dictionary with the keys "inputs", "outputs" and "metrics".
            The "metrics" value is the operator's own 'metrics' object, not a copy.

        Raises:
            KeyError: if a field named in 'inputs' or 'outputs' is missing
                from the instance.
        """
        try:
            inputs = {key: instance[key] for key in self.inputs}
        except KeyError as e:
            # The two message fragments are separated by a space so the
            # rendered text does not run together.
            raise KeyError(
                f"Unexpected FormTask input column names ({[key for key in self.inputs if key not in instance]}). "
                f"The available input names: {list(instance.keys())}"
            ) from e
        try:
            outputs = {key: instance[key] for key in self.outputs}
        except KeyError as e:
            raise KeyError(
                f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}"
                f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
            ) from e

        return {
            "inputs": inputs,
            "outputs": outputs,
            "metrics": self.metrics,
        }


class MultipleChoiceTask(FormTask):
    """FormTask variant that enumerates the choices and rewrites the target accordingly.

    After the regular FormTask packing, the list found under 'choices_field'
    in the inputs is rendered as a single string of enumerated choices
    (e.g. "A. first\nB. second" with the defaults), and the first output
    field is replaced by the enumerated form of the choice it equals.

    Args:
        choices_field: name of the input field holding the list of choices.
        choices_separator: string placed between rendered choices.
        enumeration_suffix: string placed between a choice's label and its text.
        use_text_in_target: when True the target is rendered as label + text,
            otherwise as the label alone.
        alphabet: labels assigned to the choices by position; its length is the
            maximum number of choices supported.
    """

    choices_field: str = "choices"
    choices_separator: str = "\n"
    enumeration_suffix: str = ". "
    use_text_in_target: bool = False
    alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def process_single_choice(
        self, choice: str, index: int, use_text: bool = True
    ) -> str:
        """Render one choice as its label, optionally followed by its text.

        Args:
            choice: the text of the choice.
            index: position of the choice, used to pick its label from 'alphabet'.
            use_text: when True, append 'enumeration_suffix' and the choice text.

        Returns:
            The label alone, or label + suffix + text.

        Raises:
            ValueError: if 'index' is beyond the end of 'alphabet'.
        """
        try:
            processed_choice = f"{self.alphabet[index]}"
        except IndexError as e:
            raise ValueError(
                f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
            ) from e
        if use_text:
            processed_choice += f"{self.enumeration_suffix}{choice}"
        return processed_choice

    def process_choices(self, choices: List[str]) -> str:
        """Render all choices, in order, joined by 'choices_separator'."""
        return self.choices_separator.join(
            self.process_single_choice(choice, index)
            for index, choice in enumerate(choices)
        )

    def process_target(self, choices: List[str], target_index: int) -> str:
        """Render the choice at 'target_index' as the target string.

        Whether the choice text is included is controlled by 'use_text_in_target'.
        """
        return self.process_single_choice(
            choices[target_index], target_index, use_text=self.use_text_in_target
        )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pack the instance like FormTask, then enumerate choices and target.

        The first field of the packed "outputs" is treated as the target; its
        value must be equal to one of the choices.

        Args:
            instance: the flat instance to process.
            stream_name: forwarded to FormTask.process.

        Returns:
            The packed instance with the choices field replaced by the rendered
            choices string and the target field replaced by the rendered target.

        Raises:
            KeyError: if a declared input/output field is missing from the
                instance, or 'choices_field' is not among the packed inputs.
            ValueError: if there is no output field, the target value is not
                one of the choices, or there are more choices than labels.
        """
        result = super().process(instance, stream_name)

        outputs = result["outputs"]
        if not outputs:
            # Without this guard next() would leak a bare StopIteration, which
            # carries no context (and would become a RuntimeError if this call
            # happens inside a generator, per PEP 479).
            raise ValueError(
                "MultipleChoiceTask requires at least one output field to use as the target, "
                f"but the given output names are: {self.outputs}"
            )
        target_key, target_value = next(iter(outputs.items()))

        try:
            choices = result["inputs"][self.choices_field]
        except KeyError as e:
            raise KeyError(
                f"MultipleChoiceTask choices field '{self.choices_field}' is not one of the "
                f"task inputs: {list(result['inputs'].keys())}"
            ) from e

        try:
            target_index_in_choices = choices.index(target_value)
        except ValueError as e:
            raise ValueError(
                f"The target value {target_value!r} of output field '{target_key}' "
                f"is not one of the choices: {choices}"
            ) from e

        processed_choices = self.process_choices(choices)
        processed_target = self.process_target(choices, target_index_in_choices)

        result["inputs"][self.choices_field] = processed_choices
        result["outputs"][target_key] = processed_target

        return result