# NOTE: "Spaces: Sleeping" -- stray non-code text captured together with this
# file; it is not part of the module.
from datetime import datetime
from typing import Any, Dict, List, Optional

from praw import Reddit
from pydantic import AliasChoices, Field, PrivateAttr, SecretStr
from pydantic_settings import BaseSettings

from obsei.payload import TextPayload
from obsei.misc.utils import (
    DATETIME_STRING_PATTERN,
    DEFAULT_LOOKUP_PERIOD,
    convert_utc_time,
    text_from_html,
)
from obsei.source.base_source import BaseSource, BaseSourceConfig
class RedditCredInfo(BaseSettings):
    """Credentials and client options used to build the Reddit client.

    Create credentials at https://www.reddit.com/prefs/apps
    Also refer to
    https://praw.readthedocs.io/en/latest/getting_started/authentication.html
    Currently Password Flow, Read Only Mode and Saved Refresh Token Mode are
    supported.

    The secret and account fields can be passed explicitly by field name or
    loaded from the environment, either under the field name or under the
    ``reddit_``-prefixed name listed in the field's alias choices.
    """

    # ``Field(env=...)`` is not honoured by pydantic v2 / pydantic-settings, so
    # the ``reddit_*`` variable names are declared through ``validation_alias``.
    # The plain field name stays the first choice so that keyword
    # construction (``RedditCredInfo(client_id=...)``) keeps working.
    # NOTE: client_id / client_secret default to None when nothing is provided.
    client_id: SecretStr = Field(
        None, validation_alias=AliasChoices("client_id", "reddit_client_id")
    )
    client_secret: SecretStr = Field(
        None,
        validation_alias=AliasChoices("client_secret", "reddit_client_secret"),
    )
    user_agent: str = "Test User Agent"
    redirect_uri: Optional[str] = None
    refresh_token: Optional[SecretStr] = Field(
        None,
        validation_alias=AliasChoices("refresh_token", "reddit_refresh_token"),
    )
    username: Optional[str] = Field(
        None, validation_alias=AliasChoices("username", "reddit_username")
    )
    password: Optional[SecretStr] = Field(
        None, validation_alias=AliasChoices("password", "reddit_password")
    )
    read_only: bool = True
class RedditConfig(BaseSourceConfig):
    """Configuration for the Reddit source.

    Holds the subreddits (and optionally the post ids) to read, the lookup
    window, whether post data is attached to every comment's meta, and the
    credentials. A ``Reddit`` client is created eagerly in ``__init__`` and
    kept in a private attribute.
    """

    # This is done to avoid exposing member to API response
    _reddit_client: Reddit = PrivateAttr()
    TYPE: str = "Reddit"
    subreddits: List[str]
    post_ids: Optional[List[str]] = None
    lookup_period: Optional[str] = None
    include_post_meta: Optional[bool] = True
    post_meta_field: str = "post_meta"
    cred_info: Optional[RedditCredInfo] = Field(None)

    def __init__(self, **data: Any):
        """Validate the config and build the Reddit client from ``cred_info``.

        When ``cred_info`` is not supplied, a default ``RedditCredInfo()`` is
        created.
        """
        super().__init__(**data)
        self.cred_info = self.cred_info or RedditCredInfo()
        cred = self.cred_info

        # client_id / client_secret default to None; guard them like the
        # optional secrets instead of failing with an AttributeError on None.
        # A missing value is handed to the Reddit client as None, leaving it
        # to PRAW to resolve or reject it.
        client_id = (
            cred.client_id.get_secret_value()
            if cred.client_id is not None
            else None
        )
        client_secret = (
            cred.client_secret.get_secret_value()
            if cred.client_secret is not None
            else None
        )
        # Empty optional secrets are passed as None (truthiness check).
        refresh_token = (
            cred.refresh_token.get_secret_value() if cred.refresh_token else None
        )
        password = cred.password.get_secret_value() if cred.password else None

        self._reddit_client = Reddit(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=cred.redirect_uri,
            user_agent=cred.user_agent,
            refresh_token=refresh_token,
            username=cred.username or None,
            password=password,
        )
        self._reddit_client.read_only = cred.read_only

    def get_reddit_client(self) -> Reddit:
        """Return the Reddit client built in ``__init__``."""
        return self._reddit_client
class RedditSource(BaseSource):
    """Source that turns comments on subreddit submissions into payloads."""

    NAME: str = "Reddit"

    def lookup(self, config: RedditConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Collect new comments from the configured subreddits.

        Args:
            config: Reddit configuration; supplies the client, the subreddits,
                the optional post id filter and the lookup period.
            **kwargs: ``id`` (optional) is the workflow id used to load and
                persist per-post state through ``self.store``.

        Returns:
            One ``TextPayload`` per collected comment. ``processed_text`` is
            the text extracted from the comment's ``body_html`` and ``meta``
            is the comment's attribute dict (plus the post data under
            ``config.post_meta_field`` when ``include_post_meta`` is set).

        State is a dict keyed by post id; every entry stores ``since_time``
        (formatted with ``DATETIME_STRING_PATTERN``) and ``since_comment_id``
        so that a later call stops at the comments already seen.
        """
        source_responses: List[TextPayload] = []

        # Get data from state
        workflow_id: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if workflow_id is None or self.store is None
            else self.store.get_source_state(workflow_id)
        )
        update_state: bool = bool(workflow_id)
        state = state or dict()

        # Build the post id filter once; an empty/missing list means no filter.
        wanted_post_ids = set(config.post_ids) if config.post_ids else None

        subreddit_reference = config.get_reddit_client().subreddit(
            "+".join(config.subreddits)
        )
        post_stream = subreddit_reference.stream.submissions(pause_after=-1)

        for post in post_stream:
            # The stream is consumed until it yields None (see PRAW's
            # ``pause_after`` documentation); without this the loop would
            # not terminate.
            if post is None:
                break

            post_data = vars(post)
            post_id = post_data["id"]
            if wanted_post_ids and post_id not in wanted_post_ids:
                continue

            post_stat: Dict[str, Any] = state.get(post_id, dict())
            lookup_period: str = post_stat.get("since_time", config.lookup_period)
            lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
            # Short values are treated as a relative period, longer ones as
            # a timestamp in DATETIME_STRING_PATTERN (as stored in state).
            if len(lookup_period) <= 5:
                since_time = convert_utc_time(lookup_period)
            else:
                since_time = datetime.strptime(lookup_period, DATETIME_STRING_PATTERN)

            last_since_time: datetime = since_time
            since_id: Optional[str] = post_stat.get("since_comment_id", None)
            last_index = since_id
            state[post_id] = post_stat

            post.comment_sort = "new"
            post.comments.replace_more(limit=None)

            # top_level_comments only
            first_comment = True
            for comment in post.comments:
                # Copy the attribute dict so that adding the post meta does
                # not write a new attribute back onto the comment object.
                comment_data = dict(vars(comment))
                if config.include_post_meta:
                    comment_data[config.post_meta_field] = post_data

                # NOTE(review): this yields a naive UTC datetime; it is
                # compared with since_time below, so both are presumably
                # naive -- confirm before switching to an aware datetime.
                comment_time = datetime.utcfromtimestamp(
                    int(comment_data["created_utc"])
                )
                comment_id = comment_data["id"]

                # Stop at the first comment older than the lookup window or
                # at the comment recorded by the previous run.
                if comment_time < since_time:
                    break
                if last_index and last_index == comment_id:
                    break

                if last_since_time < comment_time:
                    last_since_time = comment_time
                # Remember the first comment of this pass as the new marker.
                if last_index is None or first_comment:
                    last_index = comment_id
                    first_comment = False

                text = "".join(text_from_html(comment_data["body_html"]))
                source_responses.append(
                    TextPayload(
                        processed_text=text, meta=comment_data, source_name=self.NAME
                    )
                )

            post_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
            post_stat["since_comment_id"] = last_index

        if update_state and self.store is not None:
            self.store.update_source_state(workflow_id=workflow_id, state=state)

        return source_responses