File size: 5,673 Bytes
dbaa71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from datetime import datetime
from typing import Any, Dict, List, Optional

from praw import Reddit
from pydantic import Field, PrivateAttr, SecretStr
from pydantic_settings import BaseSettings

from obsei.payload import TextPayload
from obsei.misc.utils import (
    DATETIME_STRING_PATTERN,
    DEFAULT_LOOKUP_PERIOD,
    convert_utc_time,
    text_from_html,
)
from obsei.source.base_source import BaseSource, BaseSourceConfig


class RedditCredInfo(BaseSettings):
    # Settings model holding the credentials and client options that
    # RedditConfig passes to praw.Reddit(...).
    #
    # Create credential at https://www.reddit.com/prefs/apps
    # Also refer https://praw.readthedocs.io/en/latest/getting_started/authentication.html
    # Currently Password Flow, Read Only Mode and Saved Refresh Token Mode are supported
    #
    # NOTE(review): the fields below pass `env=...` to Field(). With
    # pydantic-settings v2 the environment variable name is normally derived
    # from the field name / env_prefix / validation_alias -- confirm that the
    # `reddit_*` variable names are actually honoured by the installed version.

    # Application id and secret. Both default to None even though they are
    # annotated as non-optional, so consumers must cope with a missing value.
    client_id: SecretStr = Field(None, env="reddit_client_id")
    client_secret: SecretStr = Field(None, env="reddit_client_secret")
    # User agent string forwarded to the Reddit client.
    user_agent: str = "Test User Agent"
    # Optional redirect URI forwarded to the Reddit client.
    redirect_uri: Optional[str] = None
    # Optional refresh token (Saved Refresh Token Mode).
    refresh_token: Optional[SecretStr] = Field(None, env="reddit_refresh_token")
    # Optional account credentials (Password Flow).
    username: Optional[str] = Field(None, env="reddit_username")
    password: Optional[SecretStr] = Field(None, env="reddit_password")
    # Copied onto the created client's `read_only` attribute by RedditConfig.
    read_only: bool = True


class RedditConfig(BaseSourceConfig):
    """Configuration for ``RedditSource``.

    Besides the lookup options, the config owns the PRAW ``Reddit`` client,
    which is built once in ``__init__`` from ``cred_info``.
    """

    # This is done to avoid exposing member to API response
    _reddit_client: Reddit = PrivateAttr()
    TYPE: str = "Reddit"
    # Subreddit names; they are joined with "+" when the stream is opened.
    subreddits: List[str]
    # When set (and non-empty), only posts whose id is listed are processed.
    post_ids: Optional[List[str]] = None
    # Fallback lookup window used when no per-post state is stored.
    lookup_period: Optional[str] = None
    # When true, the post's attributes are attached to every comment's meta
    # under the key named by `post_meta_field`.
    include_post_meta: Optional[bool] = True
    post_meta_field: str = "post_meta"
    # Credentials; when omitted a default RedditCredInfo() is constructed.
    cred_info: Optional[RedditCredInfo] = Field(None)

    def __init__(self, **data: Any):
        """Validate the config and create the underlying ``Reddit`` client.

        Args:
            **data: Field values forwarded to the pydantic model constructor.
        """
        super().__init__(**data)

        self.cred_info = self.cred_info or RedditCredInfo()
        cred = self.cred_info

        # client_id / client_secret default to None on RedditCredInfo, so they
        # must be unwrapped defensively just like the optional secrets below.
        # Passing None through lets the Reddit client report the missing
        # setting instead of failing here with an AttributeError on NoneType.
        client_id = (
            cred.client_id.get_secret_value()
            if cred.client_id is not None
            else None
        )
        client_secret = (
            cred.client_secret.get_secret_value()
            if cred.client_secret is not None
            else None
        )
        refresh_token = (
            cred.refresh_token.get_secret_value() if cred.refresh_token else None
        )
        password = cred.password.get_secret_value() if cred.password else None

        self._reddit_client = Reddit(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=cred.redirect_uri,
            user_agent=cred.user_agent,
            refresh_token=refresh_token,
            username=cred.username if cred.username else None,
            password=password,
        )

        self._reddit_client.read_only = cred.read_only

    def get_reddit_client(self) -> Reddit:
        """Return the ``Reddit`` client created in ``__init__``."""
        return self._reddit_client


class RedditSource(BaseSource):
    """Source that turns Reddit comments into ``TextPayload`` objects."""

    NAME: str = "Reddit"

    def lookup(self, config: RedditConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Collect new comments from the configured subreddits.

        For every submission yielded by the subreddit stream (optionally
        restricted to ``config.post_ids``) the comments are read newest
        first until one is older than the lookup window or is the comment
        recorded by the previous run.

        Args:
            config: Reddit configuration providing the client, the
                subreddits and the lookup options.
            **kwargs: ``id`` (optional) is the workflow id. When it is given
                and a store is configured, per-post state is loaded before
                and persisted after the lookup.

        Returns:
            One ``TextPayload`` per collected comment. ``processed_text`` is
            the text extracted from the comment's ``body_html`` and ``meta``
            holds the comment's attributes.
        """
        source_responses: List[TextPayload] = []

        # Get data from state
        workflow_id: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if workflow_id is None or self.store is None
            else self.store.get_source_state(workflow_id)
        )
        update_state: bool = bool(workflow_id)
        state = state or dict()

        subreddit_reference = config.get_reddit_client().subreddit(
            "+".join(config.subreddits)
        )
        post_stream = subreddit_reference.stream.submissions(pause_after=-1)
        for post in post_stream:
            # The stream yields None when it pauses; treat that as the end.
            if post is None:
                break

            post_data = vars(post)
            post_id = post_data["id"]
            if config.post_ids and post_id not in config.post_ids:
                continue

            post_stat: Dict[str, Any] = state.get(post_id, dict())
            lookup_period: str = post_stat.get("since_time", config.lookup_period)
            lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
            # Short values are handed to convert_utc_time; longer ones are
            # parsed as timestamps written by a previous run (see below).
            if len(lookup_period) <= 5:
                since_time = convert_utc_time(lookup_period)
            else:
                since_time = datetime.strptime(lookup_period, DATETIME_STRING_PATTERN)

            last_since_time: datetime = since_time

            # Id of the newest comment seen by the previous run, if any.
            since_id: Optional[str] = post_stat.get("since_comment_id", None)
            last_index = since_id
            state[post_id] = post_stat

            post.comment_sort = "new"
            post.comments.replace_more(limit=None)

            # top_level_comments only
            first_comment = True
            for comment in post.comments:
                comment_data = vars(comment)
                if config.include_post_meta:
                    comment_data[config.post_meta_field] = post_data

                comment_time = datetime.utcfromtimestamp(
                    int(comment_data["created_utc"])
                )
                comment_id = comment_data["id"]

                if comment_time < since_time:
                    break
                # Compare against the id stored by the previous run. The
                # running `last_index` is overwritten with the newest id on
                # the first iteration, so checking it here would never match
                # again and already-delivered comments would be re-emitted.
                if since_id and since_id == comment_id:
                    break
                if last_since_time < comment_time:
                    last_since_time = comment_time
                # Only the first comment that passes the checks above is
                # recorded as the marker for the next run.
                if first_comment:
                    last_index = comment_id
                    first_comment = False

                text = "".join(text_from_html(comment_data["body_html"]))

                source_responses.append(
                    TextPayload(
                        processed_text=text, meta=comment_data, source_name=self.NAME
                    )
                )

            post_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
            post_stat["since_comment_id"] = last_index

        if update_state and self.store is not None:
            self.store.update_source_state(workflow_id=workflow_id, state=state)

        return source_responses