# green-city-finder/src/helpers/data_loaders.py
# Author: Ashmi Banerjee
# Updates to the s-fairness calculation and refactoring of code duplication
# (commit ac20456). Helpers for loading datasets from the Hugging Face Hub.
from datasets import load_dataset
from dotenv import load_dotenv
from datasets import DatasetDict
import os
import pandas as pd
from typing import Optional
# Pull environment variables (HF_TOKEN, DATA_REPO) from a local .env file
# before anything reads os.environ.
load_dotenv()
import logging
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at DEBUG on import affects the root logger for the
# whole process — confirm this is intended outside of local development.
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
# Hugging Face access token; raises KeyError at import time if unset.
HF_TOKEN = os.environ["HF_TOKEN"]
def load_data_hf(repo_name: str, data_files: str, is_public: bool) -> DatasetDict:
    """Load a dataset from the Hugging Face Hub.

    Args:
        repo_name: Hub repository id (e.g. "org/dataset").
        data_files: Path of the data file within the repo; only used for
            private repositories.
        is_public: If True, load the default "train" split without
            authentication; otherwise authenticate with the module-level
            HF_TOKEN and load the given data file.

    Returns:
        The loaded dataset: a single split for public repos, a DatasetDict
        keyed by split for private ones.
    """
    if is_public:
        # NOTE(review): data_files is ignored for public repos — confirm intended.
        dataset = load_dataset(repo_name, split="train")
    else:
        # Fix: pass the token loaded from the environment instead of token=True,
        # which relied on a cached local login (huggingface-cli) and left the
        # required HF_TOKEN env var unused.
        dataset = load_dataset(repo_name, token=HF_TOKEN, data_files=data_files)
    return dataset
def load_scores(category: str) -> pd.DataFrame | None:
    """Load the pre-computed score file for a category from the data repo.

    Args:
        category: One of "popularity", "seasonality" or "emissions".

    Returns:
        A DataFrame with the category's scores, or None for an unknown
        category (which is also logged).
    """
    # Map each known category to its file inside the private DATA_REPO.
    score_files = {
        "popularity": "computed/popularity/popularity_scores.csv",
        "seasonality": "computed/seasonality/seasonality_scores.csv",
        "emissions": "computed/emissions/emissions_merged.csv",
    }
    data_file = score_files.get(category)
    if data_file is None:
        logger.info(f"Invalid category: {category}")
        return None
    # Delegate the load-and-wrap step to load_places instead of duplicating it.
    return load_places(data_file)
def load_places(data_file: str) -> pd.DataFrame | None:
    """Load a data file from the private DATA_REPO as a DataFrame.

    Args:
        data_file: Path of the file within the repository.

    Returns:
        A DataFrame built from the dataset's "train" split, or None when
        no data file is given.
    """
    # Guard clause: nothing to load without a file path.
    if not data_file:
        return None
    repository = os.environ.get("DATA_REPO")
    dataset = load_data_hf(repository, data_file, is_public=False)
    return pd.DataFrame(dataset["train"][:])