Spaces:
Running
Running
from datasets import load_dataset | |
from dotenv import load_dotenv | |
from datasets import DatasetDict | |
import os | |
import pandas as pd | |
from typing import Optional | |
import logging

# Must run before HF_TOKEN is read below; presumably populates os.environ
# from a local .env file (python-dotenv) -- confirm against deployment setup.
load_dotenv()

# Root logging is configured at import time: DEBUG level, UTF-8 encoded output.
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Subscript access (not .get) so a missing HF_TOKEN raises KeyError on import.
HF_TOKEN = os.environ["HF_TOKEN"]
def load_data_hf(repo_name: str, data_files: str, is_public: bool) -> DatasetDict:
    """Fetch a dataset from the Hugging Face Hub via ``load_dataset``.

    Args:
        repo_name: Repository identifier handed to ``load_dataset``.
        data_files: File path(s) inside the repository. Only used when
            ``is_public`` is false.
        is_public: Selects the access mode. When true, the "train" split of
            the repository is requested without a token and ``data_files``
            is ignored. When false, ``data_files`` is requested with
            ``token=True``.

    Returns:
        Whatever ``load_dataset`` returns for the selected mode.

    NOTE(review): the public branch passes ``split="train"``, which presumably
    yields a single ``Dataset`` rather than the annotated ``DatasetDict`` --
    confirm against callers that index the result with ``["train"]``.
    """
    if not is_public:
        # token=True presumably resolves the stored/env token -- TODO confirm.
        return load_dataset(repo_name, token=True, data_files=data_files)
    return load_dataset(repo_name, split="train")
def load_scores(category: str) -> pd.DataFrame | None:
    """Load the score table for ``category`` into a DataFrame.

    The repository name is read from the ``DATA_REPO`` environment variable
    and the file is fetched through ``load_data_hf`` with ``is_public=False``.

    Args:
        category: One of "popularity", "seasonality" or "emissions".

    Returns:
        A DataFrame built from the "train" split of the loaded data, or
        ``None`` (after logging an info message) for any other category.
    """
    # Ordered (category, file) pairs, compared with == exactly as the
    # literal patterns of a match statement would be.
    score_files = (
        ("popularity", "computed/popularity/popularity_scores.csv"),
        ("seasonality", "computed/seasonality/seasonality_scores.csv"),
        ("emissions", "computed/emissions/emissions_merged.csv"),
    )

    repository = os.environ.get("DATA_REPO")
    csv_path = next(
        (path for name, path in score_files if category == name),
        None,
    )

    if csv_path is None:
        logger.info(f"Invalid category: {category}")
        return None

    dataset = load_data_hf(repository, csv_path, is_public=False)
    return pd.DataFrame(dataset["train"][:])
def load_places(data_file: str) -> pd.DataFrame | None: | |
repository = os.environ.get("DATA_REPO") | |
if data_file: | |
data = load_data_hf(repository, data_file, is_public=False) | |
df = pd.DataFrame(data["train"][:]) | |
return df | |
return None | |