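"""Fetch articles from the RSS feeds listed in rss_feeds.py, embed their summaries
with a sentence-transformers model via LangChain, store them in a local Chroma
vector database, and sync that database to the Hugging Face dataset repo
broadfield-dev/news-rag-db."""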
import logging
import os
import shutil

import feedparser
from huggingface_hub import HfApi, login, snapshot_download
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document

import rss_feeds
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
MAX_ARTICLES_PER_FEED = 5  # Set to 5 for testing, increase later as needed
LOCAL_DB_DIR = "chroma_db"
RSS_FEEDS = rss_feeds.RSS_FEEDS
HF_API_TOKEN = os.getenv("DEMO_HF_API_TOKEN", "YOUR_HF_API_TOKEN")
REPO_ID = "broadfield-dev/news-rag-db"
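# NOTE: the token read from DEMO_HF_API_TOKEN must have write access to REPO_ID,
# since create_repo/upload_file below push the Chroma DB to that dataset repo.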
# Initialize Hugging Face API
login(token=HF_API_TOKEN)
hf_api = HfApi()
# Initialize embedding model and vector DB
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_db = Chroma(persist_directory=LOCAL_DB_DIR, embedding_function=embedding_model)
def fetch_rss_feeds():
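    """Fetch up to MAX_ARTICLES_PER_FEED entries from each configured RSS feed.

    Returns a list of article dicts (title, link, description, published,
    category, image), de-duplicated on the title|link|published key.
    """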
    articles = []
    seen_keys = set()
    for feed_url in RSS_FEEDS:
        try:
            logger.info(f"Fetching {feed_url}")
            feed = feedparser.parse(feed_url)
            if feed.bozo:
                logger.warning(f"Parse error for {feed_url}: {feed.bozo_exception}")
                continue
            article_count = 0
            for entry in feed.entries:
                if article_count >= MAX_ARTICLES_PER_FEED:
                    break
                title = entry.get("title", "No Title").strip()
                link = entry.get("link", "").strip()
                description = entry.get("summary", entry.get("description", "No Description"))
                published = entry.get("published", "Unknown Date").strip()
                key = f"{title}|{link}|{published}"
                if key not in seen_keys:
                    seen_keys.add(key)
                    image = (entry.get("media_content", [{}])[0].get("url") or
                             entry.get("media_thumbnail", [{}])[0].get("url") or "svg")
                    articles.append({
                        "title": title,
                        "link": link,
                        "description": description,
                        "published": published,
                        "category": categorize_feed(feed_url),
                        "image": image,
                    })
                    article_count += 1
        except Exception as e:
            logger.error(f"Error fetching {feed_url}: {e}")
    logger.info(f"Total articles fetched: {len(articles)}")
    return articles
def categorize_feed(url):
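    """Map a feed URL to a display category via ordered substring matching.

    Order matters: the first group with a keyword contained in the URL wins.
    """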
if "nature" in url or "science.org" in url or "arxiv.org" in url or "plos.org" in url or "annualreviews.org" in url or "journals.uchicago.edu" in url or "jneurosci.org" in url or "cell.com" in url or "nejm.org" in url or "lancet.com" in url: | |
return "Academic Papers" | |
elif "reuters.com/business" in url or "bloomberg.com" in url or "ft.com" in url or "marketwatch.com" in url or "cnbc.com" in url or "foxbusiness.com" in url or "wsj.com" in url or "bworldonline.com" in url or "economist.com" in url or "forbes.com" in url: | |
return "Business" | |
elif "investing.com" in url or "cnbc.com/market" in url or "marketwatch.com/market" in url or "fool.co.uk" in url or "zacks.com" in url or "seekingalpha.com" in url or "barrons.com" in url or "yahoofinance.com" in url: | |
return "Stocks & Markets" | |
elif "whitehouse.gov" in url or "state.gov" in url or "commerce.gov" in url or "transportation.gov" in url or "ed.gov" in url or "dol.gov" in url or "justice.gov" in url or "federalreserve.gov" in url or "occ.gov" in url or "sec.gov" in url or "bls.gov" in url or "usda.gov" in url or "gao.gov" in url or "cbo.gov" in url or "fema.gov" in url or "defense.gov" in url or "hhs.gov" in url or "energy.gov" in url or "interior.gov" in url: | |
return "Federal Government" | |
elif "weather.gov" in url or "metoffice.gov.uk" in url or "accuweather.com" in url or "weatherunderground.com" in url or "noaa.gov" in url or "wunderground.com" in url or "climate.gov" in url or "ecmwf.int" in url or "bom.gov.au" in url: | |
return "Weather" | |
elif "data.worldbank.org" in url or "imf.org" in url or "un.org" in url or "oecd.org" in url or "statista.com" in url or "kff.org" in url or "who.int" in url or "cdc.gov" in url or "bea.gov" in url or "census.gov" in url or "fdic.gov" in url: | |
return "Data & Statistics" | |
elif "nasa" in url or "spaceweatherlive" in url or "space" in url or "universetoday" in url or "skyandtelescope" in url or "esa" in url: | |
return "Space" | |
elif "sciencedaily" in url or "quantamagazine" in url or "smithsonianmag" in url or "popsci" in url or "discovermagazine" in url or "scientificamerican" in url or "newscientist" in url or "livescience" in url or "atlasobscura" in url: | |
return "Science" | |
elif "wired" in url or "techcrunch" in url or "arstechnica" in url or "gizmodo" in url or "theverge" in url: | |
return "Tech" | |
elif "horoscope" in url or "astrostyle" in url: | |
return "Astrology" | |
elif "cnn_allpolitics" in url or "bbci.co.uk/news/politics" in url or "reuters.com/arc/outboundfeeds/newsletter-politics" in url or "politico.com/rss/politics" in url or "thehill" in url: | |
return "Politics" | |
elif "weather" in url or "swpc.noaa.gov" in url or "foxweather" in url: | |
return "Earth Weather" | |
elif "vogue" in url: | |
return "Lifestyle" | |
elif "phys.org" in url or "aps.org" in url or "physicsworld" in url: | |
return "Physics" | |
return "Uncategorized" | |
def process_and_store_articles(articles):
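    """Wrap each article in a LangChain Document (description as page content,
    the remaining fields as metadata) and add the batch to the Chroma collection."""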
    documents = []
    for article in articles:
        try:
            metadata = {
                "title": article["title"],
                "link": article["link"],
                "original_description": article["description"],
                "published": article["published"],
                "category": article["category"],
                "image": article["image"],
            }
            doc = Document(page_content=article["description"], metadata=metadata)
            documents.append(doc)
        except Exception as e:
            logger.error(f"Error processing article {article['title']}: {e}")
    if documents:
        try:
            vector_db.add_documents(documents)
            logger.info(f"Stored {len(documents)} articles in DB")
        except Exception as e:
            logger.error(f"Error storing articles: {e}")
def download_from_hf_hub():
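    """Replace the local Chroma directory with the copy stored in the dataset repo."""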
    if os.path.exists(LOCAL_DB_DIR):
        shutil.rmtree(LOCAL_DB_DIR)
    try:
        hf_api.create_repo(repo_id=REPO_ID, repo_type="dataset", exist_ok=True, token=HF_API_TOKEN)
        logger.info(f"Downloading Chroma DB from {REPO_ID}...")
        # HfApi has no download_repo method; snapshot_download fetches the whole repo.
        snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=LOCAL_DB_DIR, token=HF_API_TOKEN)
    except Exception as e:
        logger.error(f"Error downloading from Hugging Face Hub: {e}")
        raise
def upload_to_hf_hub():
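    """Push every file under the local Chroma directory to the dataset repo, one upload per file."""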
    if os.path.exists(LOCAL_DB_DIR):
        try:
            hf_api.create_repo(repo_id=REPO_ID, repo_type="dataset", exist_ok=True, token=HF_API_TOKEN)
            logger.info(f"Uploading Chroma DB to {REPO_ID}...")
            for root, _, files in os.walk(LOCAL_DB_DIR):
                for file in files:
                    local_path = os.path.join(root, file)
                    remote_path = os.path.relpath(local_path, LOCAL_DB_DIR)
                    hf_api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=remote_path,
                        repo_id=REPO_ID,
                        repo_type="dataset",
                        token=HF_API_TOKEN,
                    )
            logger.info(f"Database uploaded to: {REPO_ID}")
        except Exception as e:
            logger.error(f"Error uploading to Hugging Face Hub: {e}")
            raise
if __name__ == "__main__":
    articles = fetch_rss_feeds()
    process_and_store_articles(articles)
    upload_to_hf_hub()
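    # A minimal sketch of how the stored articles could be queried afterwards,
    # assuming the same embedding model and persist directory (the query string
    # below is only an illustration):
    #
    #     results = vector_db.similarity_search("federal reserve rate decision", k=5)
    #     for doc in results:
    #         print(doc.metadata["title"], doc.metadata["link"])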