Spaces:
Sleeping
Sleeping
File size: 5,673 Bytes
dbaa71b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from datetime import datetime
from typing import Any, Dict, List, Optional
from praw import Reddit
from pydantic import Field, PrivateAttr, SecretStr
from pydantic_settings import BaseSettings
from obsei.payload import TextPayload
from obsei.misc.utils import (
DATETIME_STRING_PATTERN,
DEFAULT_LOOKUP_PERIOD,
convert_utc_time,
text_from_html,
)
from obsei.source.base_source import BaseSource, BaseSourceConfig
class RedditCredInfo(BaseSettings):
    """Credentials and client options used to build a PRAW ``Reddit`` client.

    Every credential field defaults to ``None``; which ones must be set
    depends on the authentication flow in use (see the links below).
    """

    # Create credential at https://www.reddit.com/prefs/apps
    # Also refer https://praw.readthedocs.io/en/latest/getting_started/authentication.html
    # Currently Password Flow, Read Only Mode and Saved Refresh Token Mode are supported

    # NOTE(review): the ``env=`` keyword on ``Field`` is the pydantic v1 way of
    # naming an environment variable. This module imports ``pydantic_settings``
    # (pydantic v2), where that keyword is presumably no longer honoured and
    # ``validation_alias`` is the documented replacement - TODO confirm and
    # migrate once the init-kwarg/env precedence has been verified.

    # Both are Optional because their default is None (the previous plain
    # ``SecretStr`` annotation contradicted the default).
    client_id: Optional[SecretStr] = Field(None, env="reddit_client_id")
    client_secret: Optional[SecretStr] = Field(None, env="reddit_client_secret")
    user_agent: str = "Test User Agent"
    redirect_uri: Optional[str] = None
    refresh_token: Optional[SecretStr] = Field(None, env="reddit_refresh_token")
    username: Optional[str] = Field(None, env="reddit_username")
    password: Optional[SecretStr] = Field(None, env="reddit_password")
    # Copied onto the PRAW client's ``read_only`` attribute by RedditConfig.
    read_only: bool = True
class RedditConfig(BaseSourceConfig):
    """Source configuration for Reddit that also owns the PRAW client.

    The client is created eagerly in ``__init__`` from ``cred_info`` (or from
    a default ``RedditCredInfo()`` when none is supplied) and is exposed via
    ``get_reddit_client()``.
    """

    # This is done to avoid exposing member to API response
    _reddit_client: Reddit = PrivateAttr()
    TYPE: str = "Reddit"
    # Subreddit names to read from.
    subreddits: List[str]
    # When set, only posts whose id is listed here are processed.
    post_ids: Optional[List[str]] = None
    # Either a short relative period or a timestamp string; interpreted by the source.
    lookup_period: Optional[str] = None
    # When true, the post's data is attached to each comment's metadata
    # under the key named by ``post_meta_field``.
    include_post_meta: Optional[bool] = True
    post_meta_field: str = "post_meta"
    cred_info: Optional[RedditCredInfo] = Field(None)

    def __init__(self, **data: Any):
        """Validate the config fields, then build the PRAW ``Reddit`` client.

        Args:
            **data: Field values forwarded to the pydantic model constructor.
        """
        super().__init__(**data)

        self.cred_info = self.cred_info or RedditCredInfo()
        cred = self.cred_info

        # client_id / client_secret default to None on RedditCredInfo, so they
        # must be unwrapped defensively: calling get_secret_value() on None
        # raised an opaque AttributeError. A missing value is forwarded as
        # None so PRAW applies its own validation.
        # NOTE(review): presumably PRAW reports a descriptive error for a
        # missing client_id - confirm against the PRAW version in use.
        client_id = (
            cred.client_id.get_secret_value() if cred.client_id is not None else None
        )
        client_secret = (
            cred.client_secret.get_secret_value()
            if cred.client_secret is not None
            else None
        )

        # Optional secrets: an unset or empty value is passed as None.
        refresh_token = (
            cred.refresh_token.get_secret_value() if cred.refresh_token else None
        )
        password = cred.password.get_secret_value() if cred.password else None

        self._reddit_client = Reddit(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=cred.redirect_uri,
            user_agent=cred.user_agent,
            refresh_token=refresh_token,
            username=cred.username if cred.username else None,
            password=password,
        )
        self._reddit_client.read_only = cred.read_only

    def get_reddit_client(self) -> Reddit:
        """Return the PRAW client created in ``__init__``."""
        return self._reddit_client
class RedditSource(BaseSource):
    """Source that collects top-level comments from subreddit submissions."""

    NAME: str = "Reddit"

    def lookup(self, config: RedditConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Fetch new top-level comments for the configured subreddits.

        For every submission yielded by the subreddit stream (optionally
        restricted to ``config.post_ids``), comments are read newest-first
        until one is older than the lookup window or is the comment recorded
        by the previous run. Per-post progress (``since_time`` and
        ``since_comment_id``) is kept in the state dict and persisted through
        ``self.store`` when a workflow id is supplied.

        Args:
            config: Reddit configuration providing the PRAW client and filters.
            **kwargs: May contain ``id``, the workflow id used to load and
                save state.

        Returns:
            One ``TextPayload`` per collected comment; ``meta`` holds the
            comment's attributes (plus the post's attributes when
            ``config.include_post_meta`` is set).
        """
        source_responses: List[TextPayload] = []

        # Get data from state
        workflow_id: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if workflow_id is None or self.store is None
            else self.store.get_source_state(workflow_id)
        )
        update_state: bool = bool(workflow_id)
        state = state or dict()

        subreddit_reference = config.get_reddit_client().subreddit(
            "+".join(config.subreddits)
        )
        # pause_after=-1 makes the stream yield None instead of blocking,
        # which is the signal used below to stop iterating.
        post_stream = subreddit_reference.stream.submissions(pause_after=-1)
        for post in post_stream:
            if post is None:
                break

            post_data = vars(post)
            post_id = post_data["id"]
            if config.post_ids and post_id not in config.post_ids:
                continue

            post_stat: Dict[str, Any] = state.get(post_id, dict())
            lookup_period: str = post_stat.get("since_time", config.lookup_period)
            lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
            # Strings of at most 5 characters are treated as relative periods;
            # anything longer is parsed as a stored timestamp.
            if len(lookup_period) <= 5:
                since_time = convert_utc_time(lookup_period)
            else:
                since_time = datetime.strptime(lookup_period, DATETIME_STRING_PATTERN)

            last_since_time: datetime = since_time

            # Id of the newest comment seen by the previous run (if any).
            since_id: Optional[str] = post_stat.get("since_comment_id", None)
            # Becomes the id of the newest comment seen in this run.
            newest_comment_id: Optional[str] = since_id
            state[post_id] = post_stat

            post.comment_sort = "new"
            post.comments.replace_more(limit=None)

            # top_level_comments only
            first_comment = True
            for comment in post.comments:
                comment_data = vars(comment)
                if config.include_post_meta:
                    comment_data[config.post_meta_field] = post_data
                comment_time = datetime.utcfromtimestamp(
                    int(comment_data["created_utc"])
                )
                comment_id = comment_data["id"]

                if comment_time < since_time:
                    break
                # Compare against the persisted id, not the running
                # "newest" id: the latter is overwritten by the first comment
                # of this run, which previously let the last-seen comment
                # (whose time equals since_time) be emitted again.
                if since_id and since_id == comment_id:
                    break

                if last_since_time < comment_time:
                    last_since_time = comment_time
                if first_comment:
                    newest_comment_id = comment_id
                    first_comment = False

                text = "".join(text_from_html(comment_data["body_html"]))

                source_responses.append(
                    TextPayload(
                        processed_text=text, meta=comment_data, source_name=self.NAME
                    )
                )

            post_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
            post_stat["since_comment_id"] = newest_comment_id

        if update_state and self.store is not None:
            self.store.update_source_state(workflow_id=workflow_id, state=state)

        return source_responses
|