File size: 2,057 Bytes
dbaa71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import List, Optional, Dict, Any

from obsei.payload import TextPayload
from obsei.postprocessor.base_postprocessor import (
    BasePostprocessorConfig,
    BasePostprocessor
)
from obsei.postprocessor.inference_aggregator_function import BaseInferenceAggregateFunction
from obsei.preprocessor.text_splitter import TextSplitterPayload


class InferenceAggregatorConfig(BasePostprocessorConfig):
    """Configuration passed to ``InferenceAggregator.postprocess_input``."""

    # Aggregation strategy; its ``execute`` method is invoked once per group of
    # payloads and its results are concatenated into the postprocessor output.
    aggregate_function: BaseInferenceAggregateFunction


class InferenceAggregator(BasePostprocessor):
    """Postprocessor that groups payloads per document and aggregates each group.

    Payloads are grouped using the ``document_id`` found in their ``"splitter"``
    meta entry, ordered by ``chunk_id`` within a group, and every group is then
    handed to the aggregate function supplied by the config.
    """

    def postprocess_input(  # type: ignore[override]
        self, input_list: List[TextPayload], config: InferenceAggregatorConfig, **kwargs: Any
    ) -> List[TextPayload]:
        """Aggregate ``input_list`` one document group at a time.

        Args:
            input_list: Payloads to aggregate.
            config: Provides ``aggregate_function``; its ``execute`` method is
                called once for every group of payloads.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The results of ``aggregate_function.execute`` for every group,
            concatenated in the order the groups were first encountered.
        """
        results: List[TextPayload] = []
        for document_payloads in self.segregate_payload(input_list).values():
            results.extend(config.aggregate_function.execute(document_payloads))
        return results

    @staticmethod
    def segregate_payload(
        input_list: List[TextPayload],
    ) -> Dict[str, List[TextPayload]]:
        """Group payloads by document and order each group by chunk id.

        A payload whose ``meta`` holds a truthy ``"splitter"`` entry is grouped
        under that entry's ``document_id``. Any other payload is grouped under
        its position in ``input_list`` rendered as a string.

        Args:
            input_list: Payloads to group.

        Returns:
            Mapping from group key to the payloads in that group. A group whose
            first payload carries splitter metadata is sorted in place by
            ``meta["splitter"].chunk_id``; other groups keep input order.
        """
        groups: Dict[str, List[TextPayload]] = {}

        # Bucket every payload under its document key.
        # NOTE(review): the positional fallback key could coincide with a real
        # splitter document_id if ids may look like small integers - confirm.
        for position, item in enumerate(input_list):
            splitter_info: Optional[TextSplitterPayload] = None
            if item.meta:
                splitter_info = item.meta.get("splitter", None)

            if splitter_info:
                group_key = splitter_info.document_id
            else:
                group_key = str(position)

            groups.setdefault(group_key, []).append(item)

        # Restore chunk order inside every group that came from the splitter.
        for members in groups.values():
            if not members:
                continue
            leading_meta = members[0].meta
            if leading_meta and leading_meta.get("splitter", None):
                members.sort(key=lambda x: x.meta["splitter"].chunk_id)  # type: ignore[no-any-return]

        return groups