# kltn20133118's picture
# Upload 337 files
# dbaa71b verified
from datetime import datetime
from typing import Any, Dict, List, Optional

from praw import Reddit
from pydantic import AliasChoices, Field, PrivateAttr, SecretStr
from pydantic_settings import BaseSettings

from obsei.misc.utils import (
    DATETIME_STRING_PATTERN,
    DEFAULT_LOOKUP_PERIOD,
    convert_utc_time,
    text_from_html,
)
from obsei.payload import TextPayload
from obsei.source.base_source import BaseSource, BaseSourceConfig
class RedditCredInfo(BaseSettings):
    """Reddit API credentials, supplied as keyword arguments or environment variables.

    Each credential field can be passed by its field name (for example
    ``client_id=...``) or read from the matching ``reddit_*`` environment
    variable (for example ``reddit_client_id``).
    """

    # Create credential at https://www.reddit.com/prefs/apps
    # Also refer https://praw.readthedocs.io/en/latest/getting_started/authentication.html
    # Currently Password Flow, Read Only Mode and Saved Refresh Token Mode are supported
    #
    # The environment variable names are declared with ``validation_alias``:
    # the legacy ``env=`` keyword of ``Field`` is not honoured by
    # pydantic-settings. The field name is listed first in every
    # ``AliasChoices`` so that explicit keyword arguments keep working and take
    # precedence over the environment.
    client_id: Optional[SecretStr] = Field(
        None, validation_alias=AliasChoices("client_id", "reddit_client_id")
    )
    client_secret: Optional[SecretStr] = Field(
        None, validation_alias=AliasChoices("client_secret", "reddit_client_secret")
    )
    user_agent: str = "Test User Agent"
    redirect_uri: Optional[str] = None
    refresh_token: Optional[SecretStr] = Field(
        None, validation_alias=AliasChoices("refresh_token", "reddit_refresh_token")
    )
    username: Optional[str] = Field(
        None, validation_alias=AliasChoices("username", "reddit_username")
    )
    password: Optional[SecretStr] = Field(
        None, validation_alias=AliasChoices("password", "reddit_password")
    )
    read_only: bool = True
class RedditConfig(BaseSourceConfig):
    """Configuration for the Reddit source; builds the praw client on creation.

    Attributes are plain pydantic fields. The praw ``Reddit`` client is kept in
    a private attribute and exposed through :meth:`get_reddit_client`.
    """

    # This is done to avoid exposing member to API response
    _reddit_client: Reddit = PrivateAttr()
    TYPE: str = "Reddit"
    # Subreddit names; they are joined with "+" when the source queries them.
    subreddits: List[str]
    # When set, only submissions whose id is listed here are processed.
    post_ids: Optional[List[str]] = None
    lookup_period: Optional[str] = None
    # When true, the submission's data is attached to every comment's meta
    # under the key named by ``post_meta_field``.
    include_post_meta: Optional[bool] = True
    post_meta_field: str = "post_meta"
    cred_info: Optional[RedditCredInfo] = Field(None)

    def __init__(self, **data: Any):
        """Validate the configuration and create the praw client.

        Args:
            **data: Field values for this configuration. When ``cred_info`` is
                omitted, a default ``RedditCredInfo()`` is created.
        """
        super().__init__(**data)

        self.cred_info = self.cred_info or RedditCredInfo()
        creds = self.cred_info

        # ``client_id`` / ``client_secret`` default to None on RedditCredInfo,
        # so unwrap them null-safely instead of failing with an AttributeError
        # on ``None.get_secret_value()``. Missing values are handed to praw as
        # None, which is expected to validate required settings itself
        # (NOTE(review): confirm against the PRAW configuration docs).
        client_id = (
            creds.client_id.get_secret_value()
            if creds.client_id is not None
            else None
        )
        client_secret = (
            creds.client_secret.get_secret_value()
            if creds.client_secret is not None
            else None
        )
        # Optional secrets: empty or missing values are passed as None.
        refresh_token = (
            creds.refresh_token.get_secret_value() if creds.refresh_token else None
        )
        password = creds.password.get_secret_value() if creds.password else None

        self._reddit_client = Reddit(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=creds.redirect_uri,
            user_agent=creds.user_agent,
            refresh_token=refresh_token,
            username=creds.username or None,
            password=password,
        )
        self._reddit_client.read_only = creds.read_only

    def get_reddit_client(self) -> Reddit:
        """Return the praw ``Reddit`` client created in ``__init__``."""
        return self._reddit_client
class RedditSource(BaseSource):
    """Source that collects top-level comments of submissions from subreddits."""

    NAME: str = "Reddit"

    def lookup(self, config: RedditConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Collect new top-level comments for the configured subreddits.

        For every submission yielded by the subreddit stream, comments are read
        newest-first until one is older than the lookup time or is the comment
        recorded by the previous run. Per-submission progress (``since_time``
        and ``since_comment_id``) is kept in the workflow state.

        Args:
            config: Reddit configuration providing the client, the subreddits,
                the optional submission filter and the lookup period.
            **kwargs: ``id`` (optional) is the workflow id used to load and
                persist state through ``self.store``.

        Returns:
            One ``TextPayload`` per collected comment, with the comment's
            attributes as ``meta``.
        """
        source_responses: List[TextPayload] = []

        # Get data from state
        workflow_id: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if workflow_id is None or self.store is None
            else self.store.get_source_state(workflow_id)
        )
        update_state: bool = bool(workflow_id)
        state = state or {}

        subreddit_reference = config.get_reddit_client().subreddit(
            "+".join(config.subreddits)
        )
        post_stream = subreddit_reference.stream.submissions(pause_after=-1)
        for post in post_stream:
            # With ``pause_after`` set, the stream yields None when it pauses;
            # that is treated as the end of this lookup.
            if post is None:
                break

            # ``vars`` returns the live attribute dict of the praw object.
            post_data = vars(post)
            post_id = post_data["id"]
            if config.post_ids and post_id not in config.post_ids:
                continue

            post_stat: Dict[str, Any] = state.get(post_id, {})
            lookup_period: Optional[str] = post_stat.get(
                "since_time", config.lookup_period
            )
            lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
            # Values of at most 5 characters go through convert_utc_time;
            # anything longer is parsed as a stored timestamp.
            if len(lookup_period) <= 5:
                since_time = convert_utc_time(lookup_period)
            else:
                since_time = datetime.strptime(lookup_period, DATETIME_STRING_PATTERN)

            last_since_time: datetime = since_time

            # Id of the newest comment handled by the previous run, if any.
            since_id: Optional[str] = post_stat.get("since_comment_id", None)
            # Id to store for the next run; stays unchanged when nothing new
            # is found.
            newest_comment_id: Optional[str] = since_id

            state[post_id] = post_stat

            post.comment_sort = "new"
            post.comments.replace_more(limit=None)

            # top_level_comments only
            first_comment = True
            for comment in post.comments:
                comment_data = vars(comment)
                if config.include_post_meta:
                    comment_data[config.post_meta_field] = post_data

                # NOTE(review): datetime.utcfromtimestamp is deprecated since
                # Python 3.12. It is kept because the tz-awareness of
                # ``since_time`` depends on helpers not visible here.
                comment_time = datetime.utcfromtimestamp(
                    int(comment_data["created_utc"])
                )
                comment_id = comment_data["id"]

                if comment_time < since_time:
                    break
                # Stop at the comment recorded by the previous run. This must
                # compare against the *stored* id: the id tracked for the next
                # run is overwritten by the first comment of this run, so
                # comparing against it would re-emit the old comment.
                if since_id and since_id == comment_id:
                    break

                if last_since_time < comment_time:
                    last_since_time = comment_time
                if first_comment:
                    # The first comment iterated is the one recorded for the
                    # next run (comments are requested with sort "new").
                    newest_comment_id = comment_id
                    first_comment = False

                text = "".join(text_from_html(comment_data["body_html"]))

                source_responses.append(
                    TextPayload(
                        processed_text=text, meta=comment_data, source_name=self.NAME
                    )
                )

            post_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
            post_stat["since_comment_id"] = newest_comment_id

        if update_state and self.store is not None:
            self.store.update_source_state(workflow_id=workflow_id, state=state)

        return source_responses