Spaces:
Sleeping
Sleeping
File size: 2,057 Bytes
dbaa71b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from typing import List, Optional, Dict, Any
from obsei.payload import TextPayload
from obsei.postprocessor.base_postprocessor import (
BasePostprocessorConfig,
BasePostprocessor
)
from obsei.postprocessor.inference_aggregator_function import BaseInferenceAggregateFunction
from obsei.preprocessor.text_splitter import TextSplitterPayload
class InferenceAggregatorConfig(BasePostprocessorConfig):
    """Configuration for :class:`InferenceAggregator`.

    Carries the aggregate function used to merge chunk-level payloads
    back into document-level results.
    """

    # Strategy object whose ``execute`` combines a list of chunk payloads
    # belonging to one document into aggregated payload(s).
    aggregate_function: BaseInferenceAggregateFunction
class InferenceAggregator(BasePostprocessor):
    """Postprocessor that regroups chunk-level payloads by source document
    and merges each group with the configured aggregate function."""

    def postprocess_input(  # type: ignore[override]
        self, input_list: List[TextPayload], config: InferenceAggregatorConfig, **kwargs: Any
    ) -> List[TextPayload]:
        """Aggregate per-document groups of payloads.

        :param input_list: payloads, possibly chunks produced by a text splitter
            (carrying ``meta["splitter"]`` metadata)
        :param config: supplies the aggregate function applied to each group
        :return: aggregated payloads, one group's output per source document
        """
        postprocess_output: List[TextPayload] = []
        # Only the grouped payload lists are needed here, not the doc ids.
        for payload_list in self.segregate_payload(input_list).values():
            postprocess_output.extend(config.aggregate_function.execute(payload_list))
        return postprocess_output

    @staticmethod
    def segregate_payload(
        input_list: List[TextPayload],
    ) -> Dict[str, List[TextPayload]]:
        """Bucket payloads by originating document and order chunks.

        Payloads carrying splitter metadata are grouped under their
        ``document_id``; payloads without it become singleton groups keyed
        by their input index, so they pass through unchanged.

        :param input_list: payloads to group
        :return: mapping of document id -> chunk payloads sorted by chunk id
        """
        segregated_payload: Dict[str, List[TextPayload]] = {}
        for idx, payload in enumerate(input_list):
            splitter_data: Optional[TextSplitterPayload] = (
                payload.meta.get("splitter", None) if payload.meta else None
            )
            # Fall back to the list index when no splitter metadata exists,
            # giving un-split payloads their own one-element group.
            doc_id = splitter_data.document_id if splitter_data else str(idx)
            segregated_payload.setdefault(doc_id, []).append(payload)
        # Restore original in-document order: sort chunks by chunk id.
        for payloads in segregated_payload.values():
            if (
                len(payloads) > 0
                and payloads[0].meta
                and payloads[0].meta.get("splitter", None)
            ):
                payloads.sort(key=lambda x: x.meta["splitter"].chunk_id)  # type: ignore[no-any-return]
        return segregated_payload
|