# Spaces: Runtime error
# src/datatonic/dataloader.py
import json

from datasets import load_dataset
class DataLoader:
    """Load supported datasets by short name and reduce them to JSON-ready rows.

    ``self.datasets`` maps each supported short name (for example
    ``"gpl-fiqa"``) to the bound method that loads it.  Every loader fetches
    its dataset with ``load_dataset`` and passes the result through
    :meth:`process_dataset`, so all loaders return the same row structure.
    """

    def __init__(self):
        """Build the short-name -> loader-method dispatch table."""
        self.datasets = {
            "gpl-fiqa": self.load_gpl_fiqa,
            "msmarco": self.load_msmarco,
            "nfcorpus": self.load_nfcorpus,
            "covid19": self.load_covid19,
            "gpl-webis-touche2020": self.load_gpl_webis_touche2020,
            "gpl-hotpotqa": self.load_gpl_hotpotqa,
            "gpl-nq": self.load_gpl_nq,
            "gpl-fever": self.load_gpl_fever,
            "gpl-scidocs": self.load_gpl_scidocs,
            "gpl-scifact": self.load_gpl_scifact,
            "gpl-cqadupstack": self.load_gpl_cqadupstack,
            "gpl-arguana": self.load_gpl_arguana,
            "gpl-climate-fever": self.load_gpl_climate_fever,
            "gpl-dbpedia-entity": self.load_gpl_dbpedia_entity,
            "gpl-all-mix-450k": self.load_gpl_all_mix_450k,
        }

    def load_dataset_generic(self, dataset_name, split="train"):
        """Fetch ``dataset_name`` with ``load_dataset`` and process one split.

        Args:
            dataset_name: Identifier handed unchanged to ``load_dataset``.
            split: Name of the split to process; defaults to ``"train"``.

        Returns:
            The list of row dicts produced by :meth:`process_dataset`.
        """
        dataset = load_dataset(dataset_name)
        return self.process_dataset(dataset, split=split)

    # One thin loader per supported dataset.  They are kept as individual
    # methods because ``self.datasets`` stores them as bound methods.

    def load_gpl_fiqa(self):
        """Load and process ``nthakur/gpl-fiqa``."""
        return self.load_dataset_generic("nthakur/gpl-fiqa")

    def load_msmarco(self):
        """Load and process ``nthakur/msmarco-passage-sampled-100k``."""
        return self.load_dataset_generic("nthakur/msmarco-passage-sampled-100k")

    def load_nfcorpus(self):
        """Load and process ``nthakur/gpl-nfcorpus``."""
        return self.load_dataset_generic("nthakur/gpl-nfcorpus")

    def load_covid19(self):
        """Load and process ``nthakur/gpl-trec-covid``."""
        return self.load_dataset_generic("nthakur/gpl-trec-covid")

    def load_gpl_webis_touche2020(self):
        """Load and process ``nthakur/gpl-webis-touche2020``."""
        return self.load_dataset_generic("nthakur/gpl-webis-touche2020")

    def load_gpl_hotpotqa(self):
        """Load and process ``nthakur/gpl-hotpotqa``."""
        return self.load_dataset_generic("nthakur/gpl-hotpotqa")

    def load_gpl_nq(self):
        """Load and process ``nthakur/gpl-nq``."""
        return self.load_dataset_generic("nthakur/gpl-nq")

    def load_gpl_fever(self):
        """Load and process ``nthakur/gpl-fever``."""
        return self.load_dataset_generic("nthakur/gpl-fever")

    def load_gpl_scidocs(self):
        """Load and process ``nthakur/gpl-scidocs``."""
        return self.load_dataset_generic("nthakur/gpl-scidocs")

    def load_gpl_scifact(self):
        """Load and process ``nthakur/gpl-scifact``."""
        return self.load_dataset_generic("nthakur/gpl-scifact")

    def load_gpl_cqadupstack(self):
        """Load and process ``nthakur/gpl-cqadupstack``."""
        return self.load_dataset_generic("nthakur/gpl-cqadupstack")

    def load_gpl_arguana(self):
        """Load and process ``nthakur/gpl-arguana``."""
        return self.load_dataset_generic("nthakur/gpl-arguana")

    def load_gpl_climate_fever(self):
        """Load and process ``nthakur/gpl-climate-fever``."""
        return self.load_dataset_generic("nthakur/gpl-climate-fever")

    def load_gpl_dbpedia_entity(self):
        """Load and process ``nthakur/gpl-dbpedia-entity``."""
        return self.load_dataset_generic("nthakur/gpl-dbpedia-entity")

    def load_gpl_all_mix_450k(self):
        """Load and process ``nthakur/gpl-all-mix-450k``."""
        return self.load_dataset_generic("nthakur/gpl-all-mix-450k")

    def process_dataset(self, dataset, split="train"):
        """Reduce one split of ``dataset`` to the required JSON structure.

        Args:
            dataset: Mapping from split name to an iterable of row mappings
                (each row must support ``.get``).
            split: Name of the split to read; defaults to ``"train"``.

        Returns:
            A list with one dict per row holding exactly the keys ``query``,
            ``positive_passages`` and ``negative_passages``.  A key missing
            from a row is filled with ``""`` (query) or a fresh ``[]``
            (passages); all other row fields are dropped.

        Raises:
            KeyError: If ``dataset`` has no split named ``split``.
        """
        try:
            rows = dataset[split]
        except KeyError as err:
            raise KeyError(f"Dataset has no '{split}' split.") from err

        # NOTE: the field names below are applied to every dataset alike; a
        # dataset whose rows use different names yields empty defaults here
        # rather than an error, so adjust per dataset if that happens.
        return [
            {
                "query": row.get("query", ""),
                "positive_passages": row.get("positive_passages", []),
                "negative_passages": row.get("negative_passages", []),
            }
            for row in rows
        ]

    def load_and_process(self, dataset_name):
        """Load the dataset registered under ``dataset_name``.

        An unsupported name is not an error: a notice is printed and the
        ``gpl-arguana`` dataset is loaded instead.

        Args:
            dataset_name: Short name, i.e. a key of ``self.datasets``.

        Returns:
            The processed rows returned by the selected loader.
        """
        loader = self.datasets.get(dataset_name)
        if loader is None:
            # Deliberate best-effort fallback: report and use the default.
            print(
                f"Dataset '{dataset_name}' not supported. "
                "Defaulting to 'gpl-arguana'."
            )
            return self.load_gpl_arguana()
        return loader()

    def save_to_json(self, data, file_name):
        """Write ``data`` to ``file_name`` as JSON indented by four spaces.

        The file is opened with an explicit UTF-8 encoding so the result does
        not depend on the platform's default encoding.

        Args:
            data: JSON-serialisable object to write.
            file_name: Path of the file to create or overwrite.
        """
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)