# src/datatonic/dataloader.py
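"""Dataset loaders for datatonic.

Pulls GPL-style retrieval datasets from the Hugging Face Hub and normalizes
them into query/positive_passages/negative_passages records for JSON export.
"""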

from datasets import load_dataset
import json

class DataLoader:
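    """Loads, normalizes, and serializes supported retrieval datasets.

    `self.datasets` maps each supported short name to the bound method that
    fetches the corresponding dataset from the Hugging Face Hub.
    """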
    def __init__(self):
        self.datasets = {
            "gpl-fiqa": self.load_gpl_fiqa,
            "msmarco": self.load_msmarco,
            "nfcorpus": self.load_nfcorpus,
            "covid19": self.load_covid19,
            "gpl-webis-touche2020": self.load_gpl_webis_touche2020,
            "gpl-hotpotqa": self.load_gpl_hotpotqa,
            "gpl-nq": self.load_gpl_nq,
            "gpl-fever": self.load_gpl_fever,
            "gpl-scidocs": self.load_gpl_scidocs,
            "gpl-scifact": self.load_gpl_scifact,
            "gpl-cqadupstack": self.load_gpl_cqadupstack,
            "gpl-arguana": self.load_gpl_arguana,
            "gpl-climate-fever": self.load_gpl_climate_fever,
            "gpl-dbpedia-entity": self.load_gpl_dbpedia_entity,
            "gpl-all-mix-450k": self.load_gpl_all_mix_450k,
        }

    def load_dataset_generic(self, dataset_name):
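        """Download `dataset_name` from the Hugging Face Hub and normalize it."""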
        dataset = load_dataset(dataset_name)
        return self.process_dataset(dataset)

    def load_gpl_fiqa(self):
        return self.load_dataset_generic("nthakur/gpl-fiqa")

    def load_msmarco(self):
        return self.load_dataset_generic("nthakur/msmarco-passage-sampled-100k")

    def load_nfcorpus(self):
        return self.load_dataset_generic("nthakur/gpl-nfcorpus")

    def load_covid19(self):
        return self.load_dataset_generic("nthakur/gpl-trec-covid")

    def load_gpl_webis_touche2020(self):
        return self.load_dataset_generic("nthakur/gpl-webis-touche2020")

    def load_gpl_hotpotqa(self):
        return self.load_dataset_generic("nthakur/gpl-hotpotqa")

    def load_gpl_nq(self):
        return self.load_dataset_generic("nthakur/gpl-nq")

    def load_gpl_fever(self):
        return self.load_dataset_generic("nthakur/gpl-fever")

    def load_gpl_scidocs(self):
        return self.load_dataset_generic("nthakur/gpl-scidocs")

    def load_gpl_scifact(self):
        return self.load_dataset_generic("nthakur/gpl-scifact")

    def load_gpl_cqadupstack(self):
        return self.load_dataset_generic("nthakur/gpl-cqadupstack")

    def load_gpl_arguana(self):
        return self.load_dataset_generic("nthakur/gpl-arguana")

    def load_gpl_climate_fever(self):
        return self.load_dataset_generic("nthakur/gpl-climate-fever")

    def load_gpl_dbpedia_entity(self):
        return self.load_dataset_generic("nthakur/gpl-dbpedia-entity")

    def load_gpl_all_mix_450k(self):
        return self.load_dataset_generic("nthakur/gpl-all-mix-450k")

    def process_dataset(self, dataset):
        """Normalize the 'train' split into the JSON structure used downstream."""
        processed_data = []
        for entry in dataset['train']:
            # GPL-style entries carry these fields directly; fall back to empty
            # values for datasets that deviate from that schema.
            processed_entry = {
                "query": entry.get("query", ""),
                "positive_passages": entry.get("positive_passages", []),
                "negative_passages": entry.get("negative_passages", [])
            }
            processed_data.append(processed_entry)
        return processed_data

    def load_and_process(self, dataset_name):
        """Load a supported dataset by name, falling back to 'gpl-arguana'."""
        if dataset_name in self.datasets:
            return self.datasets[dataset_name]()
        # Unknown name: warn and fall back to the default dataset.
        print(f"Dataset '{dataset_name}' not supported. Defaulting to 'gpl-arguana'.")
        return self.load_gpl_arguana()

    def save_to_json(self, data, file_name):
        """Write the processed examples to `file_name` as indented JSON."""
        # Pin the encoding so output is consistent across platforms.
        with open(file_name, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
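

# Example usage (sketch): assumes the `datasets` library is installed and the
# Hub datasets listed above are reachable; the dataset name and output path
# below are illustrative.
if __name__ == "__main__":
    loader = DataLoader()
    examples = loader.load_and_process("gpl-fiqa")
    loader.save_to_json(examples, "gpl-fiqa.json")
    print(f"Wrote {len(examples)} examples to gpl-fiqa.json")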