File size: 1,775 Bytes
dbaa71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import List, Optional, Any

from pandas import DataFrame

from obsei.source.base_source import BaseSource, BaseSourceConfig
from obsei.payload import TextPayload


class PandasSourceConfig(BaseSourceConfig):
    """Configuration describing how to read text records from a DataFrame.

    On construction the config checks that `text_columns` is non-empty and
    that every listed column exists in `dataframe`, then casts those columns
    to the pandas ``"string"`` dtype.

    NOTE(review): the cast is assigned back into ``self.dataframe``; if the
    base class stores the caller's DataFrame object without copying, the
    caller's frame is modified too -- confirm against `BaseSourceConfig`.

    Raises:
        ValueError: if `text_columns` is empty, if any text column is missing
            from `dataframe`, or if the string conversion raises `TypeError`.
    """

    TYPE: str = "Pandas"

    # Frame holding the rows to be converted into payloads.
    dataframe: DataFrame
    # Columns whose values are joined to build each payload's text.
    text_columns: List[str]
    # String placed between the text column values when they are joined.
    separator: str = " "
    # Columns copied into each payload's meta; None means the whole record.
    include_columns: Optional[List[str]] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        if not self.text_columns:
            raise ValueError("`text_columns` cannot be empty")

        available_columns = self.dataframe.columns
        if not all(
            text_column in available_columns for text_column in self.text_columns
        ):
            raise ValueError("Every `text_columns` should be present in `dataframe`")

        try:
            self.dataframe[self.text_columns] = self.dataframe[
                self.text_columns
            ].astype("string")
        except TypeError as e:
            # Chain the original error so the failing conversion stays
            # visible in the traceback.
            raise ValueError(
                "Unable to convert `text_columns` to string dtype"
            ) from e


class PandasSource(BaseSource):
    """Source that converts each row of a configured DataFrame into a payload."""

    NAME: Optional[str] = "Pandas"

    def lookup(self, config: PandasSourceConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Build one `TextPayload` per row of ``config.dataframe``.

        The payload text is the values of ``config.text_columns`` joined with
        ``config.separator``. Missing values (``None``, ``NaN``, ``pd.NA``)
        are skipped, because `str.join` raises `TypeError` on them; a row
        whose text columns are all missing yields an empty string.

        The payload meta is the subset of the record named by
        ``config.include_columns``, or the whole record when that is None.

        Args:
            config: Source configuration holding the DataFrame and column
                selection.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A list of payloads in DataFrame row order.

        Raises:
            KeyError: if a name in ``config.include_columns`` is not a column
                of the DataFrame.
        """
        # Local import keeps this block self-contained; the module only
        # imports `DataFrame` from pandas at file level.
        from pandas import isna

        source_responses: List[TextPayload] = []
        for record in config.dataframe.to_dict("records"):
            text_values = (
                record.get(text_column) for text_column in config.text_columns
            )
            processed_text = config.separator.join(
                value for value in text_values if not isna(value)
            )

            if config.include_columns is not None:
                meta = {key: record[key] for key in config.include_columns}
            else:
                meta = record

            source_responses.append(
                TextPayload(
                    processed_text=processed_text,
                    meta=meta,
                    source_name=self.NAME,
                )
            )

        return source_responses