Spaces:
Sleeping
Sleeping
import logging | |
from datetime import datetime | |
import pytz | |
import requests | |
from typing import Any, Dict, List, Optional | |
from pydantic import Field | |
from pydantic.types import SecretStr | |
from pydantic_settings import BaseSettings | |
from searchtweets import collect_results, gen_request_parameters | |
from obsei.source.base_source import BaseSource, BaseSourceConfig | |
from obsei.payload import TextPayload | |
from obsei.misc.utils import convert_utc_time | |
logger = logging.getLogger(__name__) | |
TWITTER_OAUTH_ENDPOINT = "https://api.twitter.com/oauth2/token" | |
DEFAULT_MAX_TWEETS = 10 | |
DEFAULT_TWEET_FIELDS = [ | |
"author_id", | |
"conversation_id", | |
"created_at", | |
"entities", | |
"geo", | |
"id", | |
"in_reply_to_user_id", | |
"lang", | |
"public_metrics", | |
"referenced_tweets", | |
"source", | |
"text", | |
"withheld", | |
] | |
DEFAULT_EXPANSIONS = [ | |
"author_id", | |
"entities.mentions.username", | |
"geo.place_id", | |
"in_reply_to_user_id", | |
"referenced_tweets.id", | |
"referenced_tweets.id.author_id", | |
] | |
DEFAULT_PLACE_FIELDS = [ | |
"contained_within", | |
"country", | |
"country_code", | |
"full_name", | |
"geo", | |
"id", | |
"name", | |
"place_type", | |
] | |
DEFAULT_USER_FIELDS = [ | |
"created_at", | |
"description", | |
"entities", | |
"id", | |
"location", | |
"name", | |
"public_metrics", | |
"url", | |
"username", | |
"verified", | |
] | |
DEFAULT_OPERATORS = ["-is:reply", "-is:retweet"] | |
class TwitterCredentials(BaseSettings): | |
bearer_token: SecretStr = Field("", env="twitter_bearer_token") | |
consumer_key: SecretStr = Field("", env="twitter_consumer_key") | |
consumer_secret: SecretStr = Field("", env="twitter_consumer_secret") | |
endpoint: str = Field( | |
"https://api.twitter.com/2/tweets/search/recent", env="twitter_endpoint" | |
) | |
extra_headers_dict: Optional[Dict[str, Any]] = None | |
class TwitterSourceConfig(BaseSourceConfig): | |
TYPE: str = "Twitter" | |
query: Optional[str] = None | |
keywords: Optional[List[str]] = None | |
hashtags: Optional[List[str]] = None | |
usernames: Optional[List[str]] = None | |
operators: Optional[List[str]] = Field(DEFAULT_OPERATORS) | |
since_id: Optional[int] = None | |
until_id: Optional[int] = None | |
lookup_period: Optional[str] = None | |
tweet_fields: Optional[List[str]] = Field(DEFAULT_TWEET_FIELDS) | |
user_fields: Optional[List[str]] = Field(DEFAULT_USER_FIELDS) | |
expansions: Optional[List[str]] = Field(DEFAULT_EXPANSIONS) | |
place_fields: Optional[List[str]] = Field(DEFAULT_PLACE_FIELDS) | |
max_tweets: int = DEFAULT_MAX_TWEETS | |
cred_info: TwitterCredentials = Field(None) | |
credential: Optional[TwitterCredentials] = None | |
def __init__(self, **data: Any): | |
super().__init__(**data) | |
self.cred_info = self.cred_info or TwitterCredentials() | |
if self.credential is not None: | |
logger.warning("`credential` is deprecated; use `cred_info`") | |
self.cred_info = self.credential | |
if self.cred_info.bearer_token.get_secret_value() == '': | |
if self.cred_info.consumer_key.get_secret_value() == '' \ | |
or self.cred_info.consumer_secret.get_secret_value() == '': | |
raise AttributeError( | |
"consumer_key and consumer_secret required to generate bearer_token via Twitter" | |
) | |
self.cred_info.bearer_token = SecretStr(self.generate_bearer_token()) | |
if self.max_tweets > 100: | |
logger.warning("Twitter API support max 100 tweets per call, hence resetting `max_tweets` to 100") | |
self.max_tweets = 100 | |
def get_twitter_credentials(self) -> Dict[str, Any]: | |
if self.cred_info.bearer_token.get_secret_value() == '': | |
self.cred_info.bearer_token = SecretStr(self.generate_bearer_token()) | |
return { | |
"bearer_token": self.cred_info.bearer_token.get_secret_value(), | |
"endpoint": self.cred_info.endpoint, | |
"extra_headers_dict": self.cred_info.extra_headers_dict, | |
} | |
# Copied from Twitter searchtweets-v2 lib | |
def generate_bearer_token(self) -> str: | |
""" | |
Return the bearer token for a given pair of consumer key and secret values. | |
""" | |
data = [("grant_type", "client_credentials")] | |
resp = requests.post( | |
TWITTER_OAUTH_ENDPOINT, | |
data=data, | |
auth=( | |
self.cred_info.consumer_key.get_secret_value(), | |
self.cred_info.consumer_secret.get_secret_value(), | |
), | |
) | |
logger.warning("Grabbing bearer token from OAUTH") | |
if resp.status_code >= 400: | |
logger.error(resp.text) | |
resp.raise_for_status() | |
return str(resp.json()["access_token"]) | |
class TwitterSource(BaseSource): | |
NAME: str = "Twitter" | |
def lookup(self, config: TwitterSourceConfig, **kwargs: Any) -> List[TextPayload]: # type: ignore[override] | |
if ( | |
not config.query | |
and not config.keywords | |
and not config.hashtags | |
and not config.usernames | |
): | |
raise AttributeError( | |
"At least one non empty parameter required (query, keywords, hashtags, and usernames)" | |
) | |
place_fields = ( | |
",".join(config.place_fields) if config.place_fields is not None else None | |
) | |
user_fields = ( | |
",".join(config.user_fields) if config.user_fields is not None else None | |
) | |
expansions = ( | |
",".join(config.expansions) if config.expansions is not None else None | |
) | |
tweet_fields = ( | |
",".join(config.tweet_fields) if config.tweet_fields is not None else None | |
) | |
# Get data from state | |
identifier: str = kwargs.get("id", None) | |
state: Optional[Dict[str, Any]] = ( | |
None | |
if identifier is None or self.store is None | |
else self.store.get_source_state(identifier) | |
) | |
since_id: Optional[int] = ( | |
config.since_id or None if state is None else state.get("since_id", None) | |
) | |
until_id: Optional[int] = ( | |
config.until_id or None if state is None else state.get("until_id", None) | |
) | |
update_state: bool = True if identifier else False | |
state = state or dict() | |
max_tweet_id = since_id | |
lookup_period = config.lookup_period | |
if lookup_period is None: | |
start_time = None | |
elif len(lookup_period) <= 5: | |
start_time = convert_utc_time(lookup_period).replace(tzinfo=pytz.UTC) | |
else: | |
start_time = datetime.strptime(lookup_period, "%Y-%m-%dT%H:%M:%S%z") | |
if since_id or until_id: | |
lookup_period = None | |
query = self._generate_query_string( | |
query=config.query, | |
keywords=config.keywords, | |
hashtags=config.hashtags, | |
usernames=config.usernames, | |
operators=config.operators, | |
) | |
source_responses: List[TextPayload] = [] | |
search_query = gen_request_parameters( | |
granularity=None, | |
query=query, | |
results_per_call=config.max_tweets, | |
place_fields=place_fields, | |
expansions=expansions, | |
user_fields=user_fields, | |
tweet_fields=tweet_fields, | |
since_id=since_id, | |
until_id=until_id, | |
start_time=lookup_period, | |
stringify=False, | |
) | |
logger.info(search_query) | |
tweets_output = collect_results( | |
query=search_query, | |
max_tweets=config.max_tweets, | |
result_stream_args=config.get_twitter_credentials(), | |
) | |
tweets: List[Dict[str, Any]] = [] | |
users: List[Dict[str, Any]] = [] | |
meta_info: Dict[str, Any] = {} | |
if not tweets_output and len(tweets_output) == 0: | |
logger.info("No Tweets found") | |
else: | |
tweets = tweets_output[0]["data"] if "data" in tweets_output[0] else tweets | |
if "includes" in tweets_output[0] and "users" in tweets_output[0]["includes"]: | |
users = tweets_output[0]["includes"]["users"] | |
meta_info = tweets_output[0]["meta"] if "meta" in tweets_output[0] else meta_info | |
# Extract user info and create user map | |
user_map: Dict[str, Dict[str, Any]] = {} | |
if len(users) > 0 and "id" in users[0]: | |
for user in users: | |
if "username" in user: | |
user["user_url"] = f'https://twitter.com/{user["username"]}' | |
user_map[user["id"]] = user | |
logger.info(f"Twitter API meta_info='{meta_info}'") | |
for tweet in tweets: | |
if "author_id" in tweet and tweet["author_id"] in user_map: | |
tweet["author_info"] = user_map.get(tweet["author_id"]) | |
source_responses.append(self._get_source_output(tweet)) | |
if start_time: | |
created_date = datetime.strptime( | |
tweet["created_at"], "%Y-%m-%dT%H:%M:%S.%f%z" | |
) | |
if start_time > created_date: | |
break | |
max_tweet_id = meta_info["newest_id"] if "newest_id" in meta_info else max_tweet_id | |
# min_tweet_id = meta_info["oldest_id"] if "oldest_id" in meta_info else min_tweet_id | |
if update_state and self.store is not None: | |
state["since_id"] = max_tweet_id | |
self.store.update_source_state(workflow_id=identifier, state=state) | |
return source_responses | |
def _generate_query_string( | |
query: Optional[str] = None, | |
keywords: Optional[List[str]] = None, | |
hashtags: Optional[List[str]] = None, | |
usernames: Optional[List[str]] = None, | |
operators: Optional[List[str]] = None, | |
) -> str: | |
if query: | |
return query | |
or_tokens = [] | |
and_tokens = [] | |
or_tokens_list = [keywords, hashtags, usernames] | |
for tokens in or_tokens_list: | |
if tokens: | |
if len(tokens) > 0: | |
or_tokens.append(f'({" OR ".join(tokens)})') | |
else: | |
or_tokens.append(f'{"".join(tokens)}') | |
and_query_str = "" | |
or_query_str = "" | |
if or_tokens: | |
if len(or_tokens) > 0: | |
or_query_str = f'{" OR ".join(or_tokens)}' | |
else: | |
or_query_str = f'{"".join(or_tokens)}' | |
if operators: | |
and_tokens.append(f'{" ".join(operators)}') | |
if and_tokens: | |
and_query_str = f' ({" ".join(and_tokens)})' if and_tokens else "" | |
return or_query_str + and_query_str | |
def _get_source_output(self, tweet: Dict[str, Any]) -> TextPayload: | |
tweet["tweet_url"] = f'https://twitter.com/twitter/statuses/{tweet["id"]}' | |
return TextPayload( | |
processed_text=tweet["text"], meta=tweet, source_name=self.NAME | |
) | |