# community-science-progress / load_dataframe.py
# Author: nielsr (HF staff) — commit 57c87c9, "First draft".
import dataclasses
from multiprocessing import cpu_count
import tqdm
import requests
import streamlit as st
import pandas as pd
from datasets import Dataset, load_dataset
from paperswithcode import PapersWithCodeClient
@dataclasses.dataclass(frozen=True)
class PaperInfo:
    """Immutable record for one paper row produced by ``get_df``.

    Instances are built by keyword from dataframe rows plus an extra
    ``paper_page`` value, so every row column must match a field name here.
    ``dataclasses.asdict`` is then used to turn them back into dataframe
    rows, so the field order below is the resulting column order. Renaming,
    adding or reordering fields therefore changes ``get_df``'s output.
    """

    date: str  # NOTE(review): later parsed with pd.to_datetime in get_data; presumably a date string — confirm
    arxiv_id: str  # arXiv identifier; get_df builds paper_page from it, add_hf_assets queries the Hub with it
    github: str  # GitHub URL column; add_metadata_batch overwrites it with the PapersWithCode lookup
    title: str  # paper title; add_metadata_batch uses it as the PapersWithCode search query
    paper_page: str  # set by get_df to f"https://huggingface.co/papers/{arxiv_id}"
    upvotes: int  # presumably from the daily-papers-stats dataset — confirm
    num_comments: int  # presumably from the daily-papers-stats dataset — confirm
def get_df() -> pd.DataFrame:
    """Load the daily-papers datasets and return them as a paper-per-row frame.

    The ``train`` splits of ``hysts-bot-data/daily-papers`` and
    ``hysts-bot-data/daily-papers-stats`` are merged on ``arxiv_id`` (pandas'
    default inner join). The row order is then reversed. Each row is passed
    through ``PaperInfo`` together with a ``paper_page`` URL built from its
    ``arxiv_id``.

    Going through ``PaperInfo`` fixes the column set and order. A merged
    column that ``PaperInfo`` does not declare, or a missing required one,
    raises ``TypeError``.

    Returns:
        DataFrame whose columns are exactly the ``PaperInfo`` fields, in
        declaration order. This also holds when there are no rows, so
        callers can still index e.g. ``df['date']``.
    """
    # tqdm.auto is a submodule: a bare ``import tqdm`` does not guarantee it is loaded.
    from tqdm.auto import tqdm as progress

    df = pd.merge(
        left=load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas(),
        right=load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas(),
        on="arxiv_id",
    )
    # Reverse the merged row order (NOTE(review): presumably to show newest first — depends on source order; confirm).
    df = df[::-1].reset_index(drop=True)

    # to_dict("records") yields plain dicts, far cheaper than building a Series per row with iterrows().
    paper_info = [
        PaperInfo(
            **record,
            paper_page=f"https://huggingface.co/papers/{record['arxiv_id']}",
        )
        for record in progress(df.to_dict("records"), total=len(df))
    ]
    # Pin the columns explicitly so an empty result still has the expected schema.
    columns = [field.name for field in dataclasses.fields(PaperInfo)]
    return pd.DataFrame([dataclasses.asdict(info) for info in paper_info], columns=columns)
def get_github_url(client: "PapersWithCodeClient", paper_title: str) -> str:
    """
    Get the official GitHub URL for a paper via the PapersWithCode API.

    The paper is searched by title and the first search result is taken as
    the match. NOTE(review): nothing checks that the hit's title actually
    matches, so a wrong top hit yields that paper's repository.

    Args:
        client: PapersWithCode API client.
        paper_title: Title used as the search query.

    Returns:
        The URL of the first repository flagged ``is_official`` for the
        matched paper. Returns ``""`` in any of these cases:
        - the title is empty;
        - the search finds nothing;
        - no repository is official;
        - an API call fails (logged as a warning).
    """
    # An empty query would still return search hits, i.e. some unrelated paper's repo.
    if not paper_title:
        return ""
    try:
        results = client.paper_list(q=paper_title).results
        if not results:
            return ""
        # The search result already carries the paper id; no need to re-fetch the paper itself.
        repositories = client.paper_repository_list(paper_id=results[0].id).results
        return next((repo.url for repo in repositories if repo.is_official), "")
    except Exception as err:
        # Best effort: one failing lookup must not abort the whole Dataset.map,
        # but it should not vanish silently either. Only Exception is caught, so
        # KeyboardInterrupt/SystemExit still propagate.
        import logging

        logging.getLogger(__name__).warning(
            "PapersWithCode lookup failed for %r: %s", paper_title, err
        )
        return ""
def add_metadata_batch(batch, client: PapersWithCodeClient):
    """
    Replace the ``github`` column of a batch with PapersWithCode lookups.

    Meant for ``Dataset.map(..., batched=True)``, where ``batch`` maps column
    names to lists of values. ``get_github_url`` is called once per entry of
    ``batch["title"]``, in order. The resulting list overwrites
    ``batch["github"]``, and the same (mutated) batch object is returned.
    """
    batch["github"] = [get_github_url(client, title) for title in batch["title"]]
    return batch
def add_hf_assets(batch, session=None, timeout: float = 30.0):
    """
    Add Hugging Face asset counts to a batch of papers.

    Each non-empty entry of ``batch["arxiv_id"]`` is looked up at
    ``https://huggingface.co/api/arxiv/<arxiv_id>/repos``. The lengths of the
    response's ``models``, ``datasets`` and ``spaces`` lists are stored in
    new ``num_models``, ``num_datasets`` and ``num_spaces`` columns. Empty or
    ``None`` ids get zero counts without any request.

    Args:
        batch: Column-name -> list-of-values mapping, as passed by
            ``Dataset.map(..., batched=True)``; mutated in place.
        session: Optional ``requests.Session`` (or any object with a
            compatible ``get``). When omitted and the batch needs lookups,
            one session is created and reused for the whole batch, then
            closed afterwards.
        timeout: Per-request timeout in seconds, so an unresponsive API
            cannot hang a worker forever.

    Returns:
        The same ``batch`` object with the three count columns added.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    arxiv_ids = batch["arxiv_id"]
    # Only open a session if there is at least one id to look up.
    owns_session = session is None and any(arxiv_ids)
    if owns_session:
        session = requests.Session()
    try:
        counts = [
            _count_hf_assets(session, arxiv_id, timeout) if arxiv_id else (0, 0, 0)
            for arxiv_id in arxiv_ids
        ]
    finally:
        if owns_session:
            session.close()

    # Same column insertion order as before: models, datasets, spaces.
    batch["num_models"] = [n_models for n_models, _, _ in counts]
    batch["num_datasets"] = [n_datasets for _, n_datasets, _ in counts]
    batch["num_spaces"] = [n_spaces for _, _, n_spaces in counts]
    return batch


def _count_hf_assets(session, arxiv_id, timeout):
    """Return ``(num_models, num_datasets, num_spaces)`` the Hub links to one arXiv id."""
    response = session.get(
        f"https://huggingface.co/api/arxiv/{arxiv_id}/repos", timeout=timeout
    )
    # Surface HTTP failures as a clear error instead of a KeyError on an error payload.
    response.raise_for_status()
    result = response.json()
    return len(result["models"]), len(result["datasets"]), len(result["spaces"])
@st.cache_data
def get_data() -> pd.DataFrame:
    """
    Build the enriched paper table; the result is cached via ``st.cache_data``.

    Pipeline:
    - ``get_df``, then parse the ``date`` column;
    - add GitHub links from PapersWithCode (``add_metadata_batch``);
    - add Hugging Face asset counts (``add_hf_assets``);
    - convert back to pandas and parse ``date`` again.

    Both enrichment passes run as batched ``Dataset.map`` calls with 4 rows
    per batch, spread over ``cpu_count()`` processes. The first rows of the
    result are printed to stdout.
    """
    papers = get_df()
    papers['date'] = pd.to_datetime(papers['date'])

    dataset = Dataset.from_pandas(papers)
    workers = cpu_count()
    # GitHub URLs via PapersWithCode; a single client is handed to every worker.
    dataset = dataset.map(
        add_metadata_batch,
        batched=True,
        batch_size=4,
        num_proc=workers,
        fn_kwargs={"client": PapersWithCodeClient()},
    )
    # Model / dataset / Space counts via the Hugging Face Hub API.
    dataset = dataset.map(add_hf_assets, batched=True, batch_size=4, num_proc=workers)

    result = dataset.to_pandas()
    result['date'] = pd.to_datetime(result['date'])
    print("First few rows of the dataset:")
    print(result.head())
    return result