import os
import feedparser
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import logging
from huggingface_hub import HfApi, login, snapshot_download
import shutil
import rss_feeds
from datetime import datetime
import dateutil.parser
import hashlib
import re  # For cleaning HTML and whitespace

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_ARTICLES_PER_FEED = 5
LOCAL_DB_DIR = "chroma_db"
RSS_FEEDS = rss_feeds.RSS_FEEDS
COLLECTION_NAME = "news_articles"
HF_API_TOKEN = os.getenv("DEMO_HF_API_TOKEN", "YOUR_HF_API_TOKEN")
REPO_ID = "broadfield-dev/news-rag-db"
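
# Note: rss_feeds is a local module not shown here. Based on how RSS_FEEDS is used
# below (iterated and passed to feedparser.parse), it is assumed to be a flat list
# of feed URL strings, for example (illustrative only):
#   RSS_FEEDS = [
#       "https://www.nasa.gov/rss/dyn/breaking_news.rss",
#       "https://www.wired.com/feed/rss",
#   ]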

# Initialize Hugging Face API
login(token=HF_API_TOKEN)
hf_api = HfApi()

# Initialize embedding model (global, reusable)
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Initialize vector DB with a specific collection name
vector_db = Chroma(
    persist_directory=LOCAL_DB_DIR,
    embedding_function=embedding_model,
    collection_name=COLLECTION_NAME
)

def clean_text(text):
    """Clean text by removing HTML tags and extra whitespace."""
    if not text or not isinstance(text, str):
        return ""
    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    # Normalize whitespace (remove extra spaces, newlines, tabs)
    text = ' '.join(text.split())
    return text.strip().lower()
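
# Example (behavior implied by the regex and normalization above):
#   clean_text("<p>Breaking:  <b>Markets</b> rally</p>")  ->  "breaking: markets rally"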

def fetch_rss_feeds():
    """Fetch up to MAX_ARTICLES_PER_FEED articles from each feed in RSS_FEEDS,
    cleaning all text fields and de-duplicating across feeds."""
    articles = []
    seen_keys = set()
    for feed_url in RSS_FEEDS:
        try:
            logger.info(f"Fetching {feed_url}")
            feed = feedparser.parse(feed_url)
            if feed.bozo:
                logger.warning(f"Parse error for {feed_url}: {feed.bozo_exception}")
                continue
            article_count = 0
            for entry in feed.entries:
                if article_count >= MAX_ARTICLES_PER_FEED:
                    break
                title = entry.get("title", "No Title")
                link = entry.get("link", "")
                description = entry.get("summary", entry.get("description", ""))
                # Clean and normalize all text fields
                title = clean_text(title)
                link = clean_text(link)
                description = clean_text(description)
                # Try multiple date fields and parse flexibly
                published = "Unknown Date"
                for date_field in ["published", "updated", "created", "pubDate"]:  # "pubDate" included for broader compatibility
                    if date_field in entry:
                        try:
                            parsed_date = dateutil.parser.parse(entry[date_field])
                            published = parsed_date.strftime("%Y-%m-%d %H:%M:%S")
                            break
                        except (ValueError, TypeError) as e:
                            logger.debug(f"Failed to parse {date_field} '{entry[date_field]}': {e}")
                            continue
                # Build a robust deduplication key from the cleaned fields (SHA-256 for better uniqueness)
                description_hash = hashlib.sha256(description.encode('utf-8')).hexdigest()
                key = f"{title}|{link}|{published}|{description_hash}"
                if key not in seen_keys:
                    seen_keys.add(key)
                    # Try multiple image sources
                    image = "svg"  # Default fallback
                    for img_source in [
                        lambda e: clean_text(e.get("media_content", [{}])[0].get("url")) if e.get("media_content") else "",
                        lambda e: clean_text(e.get("media_thumbnail", [{}])[0].get("url")) if e.get("media_thumbnail") else "",
                        lambda e: clean_text(e.get("enclosure", {}).get("url")) if e.get("enclosure") else "",
                        lambda e: clean_text(next((lnk.get("href") for lnk in e.get("links", []) if lnk.get("type", "").startswith("image")), "")),
                    ]:
                        try:
                            img = img_source(entry)
                            if img and img.strip():
                                image = img
                                break
                        except (IndexError, AttributeError, TypeError):
                            continue
                    articles.append({
                        "title": title,
                        "link": link,
                        "description": description,
                        "published": published,
                        "category": categorize_feed(feed_url),
                        "image": image,
                    })
                    article_count += 1
                else:
                    logger.debug(f"Duplicate article skipped in feed {feed_url}: {key}")
        except Exception as e:
            logger.error(f"Error fetching {feed_url}: {e}")
    logger.info(f"Total articles fetched: {len(articles)}")
    return articles
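
# Each returned article is a plain dict with these keys (values are illustrative):
#   {"title": "...", "link": "...", "description": "...",
#    "published": "2024-01-01 12:00:00" or "Unknown Date",
#    "category": "Tech", "image": "<image url>" or "svg"}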

def categorize_feed(url):
    """Map a feed URL to a category by substring match; the first matching category wins."""
    url = url.lower()
    categories = [
        ("Academic Papers", ["nature", "science.org", "arxiv.org", "plos.org", "annualreviews.org",
                             "journals.uchicago.edu", "jneurosci.org", "cell.com", "nejm.org", "lancet.com"]),
        ("Business", ["reuters.com/business", "bloomberg.com", "ft.com", "marketwatch.com", "cnbc.com",
                      "foxbusiness.com", "wsj.com", "bworldonline.com", "economist.com", "forbes.com"]),
        ("Stocks & Markets", ["investing.com", "cnbc.com/market", "marketwatch.com/market", "fool.co.uk",
                              "zacks.com", "seekingalpha.com", "barrons.com", "yahoofinance.com"]),
        ("Federal Government", ["whitehouse.gov", "state.gov", "commerce.gov", "transportation.gov", "ed.gov",
                                "dol.gov", "justice.gov", "federalreserve.gov", "occ.gov", "sec.gov", "bls.gov",
                                "usda.gov", "gao.gov", "cbo.gov", "fema.gov", "defense.gov", "hhs.gov",
                                "energy.gov", "interior.gov"]),
        ("Weather", ["weather.gov", "metoffice.gov.uk", "accuweather.com", "weatherunderground.com",
                     "noaa.gov", "wunderground.com", "climate.gov", "ecmwf.int", "bom.gov.au"]),
        ("Data & Statistics", ["data.worldbank.org", "imf.org", "un.org", "oecd.org", "statista.com",
                               "kff.org", "who.int", "cdc.gov", "bea.gov", "census.gov", "fdic.gov"]),
        ("Space", ["nasa", "spaceweatherlive", "space", "universetoday", "skyandtelescope", "esa"]),
        ("Science", ["sciencedaily", "quantamagazine", "smithsonianmag", "popsci", "discovermagazine",
                     "scientificamerican", "newscientist", "livescience", "atlasobscura"]),
        ("Tech", ["wired", "techcrunch", "arstechnica", "gizmodo", "theverge"]),
        ("Astrology", ["horoscope", "astrostyle"]),
        ("Politics", ["cnn_allpolitics", "bbci.co.uk/news/politics",
                      "reuters.com/arc/outboundfeeds/newsletter-politics", "politico.com/rss/politics", "thehill"]),
        ("Earth Weather", ["weather", "swpc.noaa.gov", "foxweather"]),
        ("Lifestyle", ["vogue"]),
        ("Physics", ["phys.org", "aps.org", "physicsworld"]),
    ]
    for category, keywords in categories:
        if any(keyword in url for keyword in keywords):
            return category
    return "Uncategorized"

def process_and_store_articles(articles):
    """Embed new articles and add them to the Chroma collection, skipping
    documents whose IDs are already present in the store."""
    documents = []
    doc_ids = []
    existing_ids = set(vector_db.get()["ids"])  # Existing document IDs, to avoid duplicates
    for article in articles:
        try:
            # Clean and normalize all fields
            title = clean_text(article["title"])
            link = clean_text(article["link"])
            description = clean_text(article["description"])
            published = article["published"]
            description_hash = hashlib.sha256(description.encode('utf-8')).hexdigest()
            doc_id = f"{title}|{link}|{published}|{description_hash}"
            if doc_id in existing_ids:
                logger.debug(f"Skipping duplicate in DB: {doc_id}")
                continue
            metadata = {
                "title": article["title"],
                "link": article["link"],
                "original_description": article["description"],
                "published": article["published"],
                "category": article["category"],
                "image": article["image"],
            }
            documents.append(Document(page_content=description, metadata=metadata))
            doc_ids.append(doc_id)
        except Exception as e:
            logger.error(f"Error processing article {article['title']}: {e}")
    if documents:
        try:
            # Pass the IDs explicitly so Chroma stores the deduplication key built above
            vector_db.add_documents(documents, ids=doc_ids)
            vector_db.persist()  # Explicitly persist changes
            logger.info(f"Added {len(documents)} new articles to DB")
        except Exception as e:
            logger.error(f"Error storing articles: {e}")

def download_from_hf_hub():
    """Download the Chroma DB from the Hub only if no local copy exists (initial setup)."""
    if not os.path.exists(LOCAL_DB_DIR):
        try:
            hf_api.create_repo(repo_id=REPO_ID, repo_type="dataset", exist_ok=True, token=HF_API_TOKEN)
            logger.info(f"Downloading Chroma DB from {REPO_ID}...")
            # HfApi has no download_repo method; snapshot_download fetches the whole dataset repo
            snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=LOCAL_DB_DIR, token=HF_API_TOKEN)
        except Exception as e:
            logger.error(f"Error downloading from Hugging Face Hub: {e}")
            raise
    else:
        logger.info("Local Chroma DB already exists, skipping download.")

def upload_to_hf_hub():
    """Upload every file in the local Chroma DB directory to the dataset repo."""
    if os.path.exists(LOCAL_DB_DIR):
        try:
            logger.info(f"Uploading updated Chroma DB to {REPO_ID}...")
            for root, _, files in os.walk(LOCAL_DB_DIR):
                for file in files:
                    local_path = os.path.join(root, file)
                    remote_path = os.path.relpath(local_path, LOCAL_DB_DIR)
                    hf_api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=remote_path,
                        repo_id=REPO_ID,
                        repo_type="dataset",
                        token=HF_API_TOKEN
                    )
            logger.info(f"Database uploaded to: {REPO_ID}")
        except Exception as e:
            logger.error(f"Error uploading to Hugging Face Hub: {e}")
            raise
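
# Note: huggingface_hub also provides HfApi.upload_folder, which could replace the
# per-file loop above with a single call, e.g.
#   hf_api.upload_folder(folder_path=LOCAL_DB_DIR, repo_id=REPO_ID, repo_type="dataset", token=HF_API_TOKEN)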

if __name__ == "__main__":
    # download_from_hf_hub() can be called first to seed LOCAL_DB_DIR on a fresh environment
    articles = fetch_rss_feeds()
    process_and_store_articles(articles)
    upload_to_hf_hub()