# src/datatonic/dataloader.py
import json
import logging

from datasets import load_dataset

logger = logging.getLogger(__name__)


class DataLoader:
    """Load named datasets via ``datasets.load_dataset`` and flatten them.

    Each supported short name (the keys of ``self.datasets``) maps to a
    bound loader method that fetches a fixed dataset id and converts one
    split into a list of ``{"query", "positive_passages",
    "negative_passages"}`` dicts.
    """

    def __init__(self):
        # Short name -> bound loader method; consulted by load_and_process.
        self.datasets = {
            "gpl-fiqa": self.load_gpl_fiqa,
            "msmarco": self.load_msmarco,
            "nfcorpus": self.load_nfcorpus,
            "covid19": self.load_covid19,
            "gpl-webis-touche2020": self.load_gpl_webis_touche2020,
            "gpl-hotpotqa": self.load_gpl_hotpotqa,
            "gpl-nq": self.load_gpl_nq,
            "gpl-fever": self.load_gpl_fever,
            "gpl-scidocs": self.load_gpl_scidocs,
            "gpl-scifact": self.load_gpl_scifact,
            "gpl-cqadupstack": self.load_gpl_cqadupstack,
            "gpl-arguana": self.load_gpl_arguana,
            "gpl-climate-fever": self.load_gpl_climate_fever,
            "gpl-dbpedia-entity": self.load_gpl_dbpedia_entity,
            "gpl-all-mix-450k": self.load_gpl_all_mix_450k,
        }

    def load_dataset_generic(self, dataset_name, split="train"):
        """Load ``dataset_name`` with ``datasets.load_dataset`` and process it.

        Args:
            dataset_name: Dataset identifier passed straight to
                ``load_dataset`` (e.g. ``"nthakur/gpl-fiqa"``).
            split: Split handed to ``process_dataset``; defaults to
                ``"train"``, matching the previous hard-coded behaviour.

        Returns:
            The list produced by ``process_dataset``.
        """
        dataset = load_dataset(dataset_name)
        return self.process_dataset(dataset, split=split)

    # Per-dataset wrappers: each loads one fixed dataset id and returns the
    # processed "train" split via load_dataset_generic.

    def load_gpl_fiqa(self):
        return self.load_dataset_generic("nthakur/gpl-fiqa")

    def load_msmarco(self):
        return self.load_dataset_generic("nthakur/msmarco-passage-sampled-100k")

    def load_nfcorpus(self):
        return self.load_dataset_generic("nthakur/gpl-nfcorpus")

    def load_covid19(self):
        return self.load_dataset_generic("nthakur/gpl-trec-covid")

    def load_gpl_webis_touche2020(self):
        return self.load_dataset_generic("nthakur/gpl-webis-touche2020")

    def load_gpl_hotpotqa(self):
        return self.load_dataset_generic("nthakur/gpl-hotpotqa")

    def load_gpl_nq(self):
        return self.load_dataset_generic("nthakur/gpl-nq")

    def load_gpl_fever(self):
        return self.load_dataset_generic("nthakur/gpl-fever")

    def load_gpl_scidocs(self):
        return self.load_dataset_generic("nthakur/gpl-scidocs")

    def load_gpl_scifact(self):
        return self.load_dataset_generic("nthakur/gpl-scifact")

    def load_gpl_cqadupstack(self):
        return self.load_dataset_generic("nthakur/gpl-cqadupstack")

    def load_gpl_arguana(self):
        return self.load_dataset_generic("nthakur/gpl-arguana")

    def load_gpl_climate_fever(self):
        return self.load_dataset_generic("nthakur/gpl-climate-fever")

    def load_gpl_dbpedia_entity(self):
        return self.load_dataset_generic("nthakur/gpl-dbpedia-entity")

    def load_gpl_all_mix_450k(self):
        return self.load_dataset_generic("nthakur/gpl-all-mix-450k")

    def process_dataset(self, dataset, split="train"):
        """Flatten one split of ``dataset`` into a list of plain dicts.

        Args:
            dataset: Container indexable by split name whose split iterates
                over dict-like rows (``.get`` is called on each row).
            split: Name of the split to read; defaults to ``"train"``.

        Returns:
            One dict per row with keys ``"query"``, ``"positive_passages"``
            and ``"negative_passages"``. A row missing a key gets ``""`` for
            the query and a fresh empty list for the passage fields.

        Raises:
            KeyError: if ``split`` is not present in ``dataset`` (presumably;
                depends on the container type passed in).
        """
        # The column names are assumed to be shared by every dataset in
        # self.datasets; adjust here if a dataset uses a different schema.
        return [
            {
                "query": entry.get("query", ""),
                "positive_passages": entry.get("positive_passages", []),
                "negative_passages": entry.get("negative_passages", []),
            }
            for entry in dataset[split]
        ]

    def load_and_process(self, dataset_name, default="gpl-arguana"):
        """Load the dataset registered under ``dataset_name``.

        Unknown names are not an error: a warning is logged and the dataset
        registered under ``default`` is loaded instead.

        Args:
            dataset_name: A key of ``self.datasets``.
            default: Key of ``self.datasets`` used as the fallback; defaults
                to ``"gpl-arguana"`` as before.

        Returns:
            The processed list returned by the selected loader.

        Raises:
            KeyError: if ``dataset_name`` is unknown and ``default`` is not a
                key of ``self.datasets`` either.
        """
        loader = self.datasets.get(dataset_name)
        if loader is not None:
            return loader()
        logger.warning(
            "Dataset '%s' not supported. Defaulting to '%s'.", dataset_name, default
        )
        return self.datasets[default]()

    def save_to_json(self, data, file_name):
        """Write ``data`` to ``file_name`` as indented (4-space) JSON.

        The file is opened as UTF-8 explicitly so the result does not depend
        on the platform's default encoding; the file is overwritten.
        """
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)