# demo_obsei/obsei_module/obsei/source/reddit_scrapper.py

import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
import mmh3
from pydantic import PrivateAttr
from reddit_rss_reader.reader import RedditContent, RedditRSSReader
from obsei.source.base_source import BaseSource, BaseSourceConfig
from obsei.payload import TextPayload
from obsei.misc.utils import (
    DATETIME_STRING_PATTERN,
    DEFAULT_LOOKUP_PERIOD,
    convert_utc_time,
)

logger = logging.getLogger(__name__)


class RedditScrapperConfig(BaseSourceConfig):
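    """Configuration for the Reddit RSS scrapper source.

    Builds a ``RedditRSSReader`` for ``url`` at init time. ``url_id``
    defaults to a 32-bit MurmurHash of the URL and doubles as the key
    under which this source's state is stored.
    """
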
    _scrapper: RedditRSSReader = PrivateAttr()

    TYPE: str = "RedditScrapper"
    url: str
    url_id: Optional[str] = None
    user_agent: Optional[str] = None
    lookup_period: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        # Derive a stable identifier from a 32-bit MurmurHash of the URL
        self.url_id = self.url_id or "{:02x}".format(mmh3.hash(self.url, signed=False))

        # Fall back to a URL-hash-based user agent when none is supplied
        self._scrapper = RedditRSSReader(
            url=self.url,
            user_agent=self.user_agent
            if self.user_agent
            else "script {url_hash}".format(url_hash=self.url_id),
        )

    def get_readers(self) -> RedditRSSReader:
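        """Return the ``RedditRSSReader`` built from this config."""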
        return self._scrapper


class RedditScrapperSource(BaseSource):
    NAME: Optional[str] = "RedditScrapper"

    def lookup(self, config: RedditScrapperConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
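        """Fetch new Reddit posts and return them as ``TextPayload`` objects.

        Tracks ``since_time``/``since_id`` per URL in the workflow store
        (when an ``id`` kwarg and a store are present), so repeated lookups
        only return content newer than the previous run.
        """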
        source_responses: List[TextPayload] = []

        # Restore per-workflow state from the store, if one is configured
        identifier: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if identifier is None or self.store is None
            else self.store.get_source_state(identifier)
        )
        update_state: bool = bool(identifier)
        state = state or dict()
        # Per-URL stats live under the url_id key inside the workflow state
        scrapper_stat: Dict[str, Any] = (
            dict() if not config.url_id else state.get(config.url_id, dict())
        )

        # Prefer the stored since_time, then the configured lookup_period,
        # then the library default
        lookup_period: Optional[str] = scrapper_stat.get("since_time", config.lookup_period)
        lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
        since_time = convert_utc_time(lookup_period)

        last_since_time: datetime = since_time
        since_id: Optional[str] = scrapper_stat.get("since_id", None)
        last_index = since_id
        if config.url_id:
            state[config.url_id] = scrapper_stat

        reddit_data: Optional[List[RedditContent]] = None
        try:
            reddit_data = config.get_readers().fetch_content(
                after=since_time, since_id=since_id
            )
        except RuntimeError as ex:
            logger.warning(ex.__cause__)
        reddit_data = reddit_data or []

        for reddit in reddit_data:
            source_responses.append(
                TextPayload(
                    processed_text=f"{reddit.title}. {reddit.extracted_text}",
                    meta=reddit.__dict__,
                    source_name=self.NAME,
                )
            )

            # Advance the since_time watermark to the newest item seen
            comment_time = reddit.updated.replace(tzinfo=timezone.utc)
            if last_since_time is None or last_since_time < comment_time:
                last_since_time = comment_time
            if last_index is None:
                # Assuming the list is sorted by time (newest first), the
                # first id seen becomes the new since_id checkpoint
                last_index = reddit.id

scrapper_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
scrapper_stat["since_id"] = last_index
if update_state and self.store is not None:
self.store.update_source_state(workflow_id=identifier, state=state)
return source_responses
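

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # fetch recent posts from a subreddit RSS feed and print a preview.
    # The subreddit URL and "1d" lookup period below are placeholder
    # assumptions, and running this requires network access to Reddit.
    example_config = RedditScrapperConfig(
        url="https://www.reddit.com/r/machinelearning/.rss",
        lookup_period="1d",
    )
    example_source = RedditScrapperSource()
    for payload in example_source.lookup(example_config):
        print(payload.processed_text[:80])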