Spaces:
Sleeping
Sleeping
File size: 4,530 Bytes
926675f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import datasets
import pandas as pd
class DataLoader:
    """Load a benchmark dataset as a pandas DataFrame, optionally sub-sampled.

    Two families of datasets are supported:

    * Hugging Face hub datasets (``hotpot_qa``, ``fever``, ``trivia_qa``,
      ``gsm8k``), fetched with ``datasets.load_dataset`` and restricted to
      two columns.
    * Local CSV files (``physics_question``, ``disfl_qa``,
      ``sports_understanding``, ``strategy_qa``, ``sotu_qa``), read with
      ``pandas.read_csv``.

    Sampling is seeded with ``self.seed`` so repeated calls return the same
    rows.
    """

    # Split names accepted by the hub-backed loaders.
    # NOTE(review): not every hub dataset necessarily provides all three
    # splits; a missing split surfaces as a lookup error from ``data[type]``.
    _SPLITS = ("train", "validation", "test")

    # Dataset names whose loader accepts a ``type`` (split) argument.
    _HUB_DATASETS = ("hotpot_qa", "fever", "trivia_qa", "gsm8k")

    # Dataset names backed by a local CSV file (no split argument).
    _CSV_DATASETS = (
        "physics_question",
        "disfl_qa",
        "sports_understanding",
        "strategy_qa",
        "sotu_qa",
    )

    def __init__(self, data="hotpot_qa", seed=2023):
        """Remember which dataset to load and the sampling seed.

        Args:
            data: Dataset name; see ``load`` for the supported values.
            seed: ``random_state`` passed to ``DataFrame.sample``.
        """
        self.data = data
        self.seed = seed

    def load(self, sample_size=None, type="train"):
        """Load the dataset selected in the constructor.

        Args:
            sample_size: Number of rows to sample. ``None`` loads every row.
            type: Split name; only forwarded to the hub-backed loaders.

        Returns:
            A pandas DataFrame.

        Raises:
            ValueError: If ``self.data`` is not a supported dataset name, or
                ``type`` is not a valid split for a hub-backed dataset.
        """
        # Loaders follow the ``load_<name>`` naming scheme, so the dataset
        # name doubles as the dispatch key.
        if self.data in self._HUB_DATASETS:
            loader = getattr(self, "load_" + self.data)
            return loader(sample_size=sample_size, type=type)
        if self.data in self._CSV_DATASETS:
            loader = getattr(self, "load_" + self.data)
            return loader(sample_size=sample_size)
        raise ValueError("Data not supported.")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _sample(self, df, sample_size, columns):
        """Return ``sample_size`` seeded random rows of ``df[columns]``."""
        picked = df.sample(sample_size, random_state=self.seed)
        return picked[columns].reset_index(drop=True)

    def _load_hub_split(self, path, columns, sample_size, split, cache_dir,
                        **load_kwargs):
        """Fetch one split of a hub dataset and keep only ``columns``.

        ``sample_size=None`` returns the whole split. Passing ``None``
        straight to ``DataFrame.sample`` would return a single row, which is
        what ``load()`` used to produce with its default arguments.
        """
        # Validate before downloading anything. An explicit raise is used
        # because ``assert`` is removed when Python runs with ``-O``.
        if split not in self._SPLITS:
            raise ValueError(
                f"type must be one of {self._SPLITS}, got {split!r}"
            )
        data = datasets.load_dataset(path, cache_dir=cache_dir, **load_kwargs)
        df = data[split].to_pandas()
        if sample_size is None:
            return df[columns].reset_index(drop=True)
        return self._sample(df, sample_size, columns)

    def _load_csv(self, csv_path, columns, sample_size):
        """Read a CSV file; sample ``columns`` only when a size is given.

        With ``sample_size=None`` the frame is returned exactly as read,
        including any columns beyond ``columns``.
        """
        df = pd.read_csv(csv_path)
        if sample_size is None:
            return df
        return self._sample(df, sample_size, columns)

    # ------------------------------------------------------------------
    # Hugging Face hub datasets
    # ------------------------------------------------------------------

    def load_hotpot_qa(self, cache_dir="data/hotpot_qa", sample_size=100, type="test"):
        """Load ``question``/``answer`` from the ``hotpot_qa`` ``fullwiki`` config.

        Raises:
            ValueError: If ``type`` is not train, validation or test.
        """
        return self._load_hub_split(
            "hotpot_qa", ["question", "answer"], sample_size, type, cache_dir,
            name="fullwiki",
        )

    def load_fever(self, cache_dir="data/fever", sample_size=100, type="test"):
        """Load ``claim``/``label`` from ``copenlu/fever_gold_evidence``.

        Raises:
            ValueError: If ``type`` is not train, validation or test.
        """
        return self._load_hub_split(
            "copenlu/fever_gold_evidence", ["claim", "label"], sample_size,
            type, cache_dir,
        )

    def load_trivia_qa(self, cache_dir="data/trivia_qa", sample_size=100, type="test"):
        """Load ``question``/``answer`` from the ``trivia_qa`` ``rc.nocontext`` config.

        Raises:
            ValueError: If ``type`` is not train, validation or test.
        """
        return self._load_hub_split(
            "trivia_qa", ["question", "answer"], sample_size, type, cache_dir,
            name="rc.nocontext",
        )

    def load_gsm8k(self, cache_dir="data/gsm8k", sample_size=100, type="test"):
        """Load ``question``/``answer`` from the ``gsm8k`` ``main`` config.

        Raises:
            ValueError: If ``type`` is not train, validation or test.
        """
        return self._load_hub_split(
            "gsm8k", ["question", "answer"], sample_size, type, cache_dir,
            name="main",
        )

    # ------------------------------------------------------------------
    # Local CSV datasets (``cache_dir`` is the path of the CSV file itself)
    # ------------------------------------------------------------------

    def load_physics_question(self, cache_dir="data/bigbench/physics_question.csv", sample_size=None):
        """Read the physics-question CSV; sample ``input``/``target`` if sized."""
        return self._load_csv(cache_dir, ["input", "target"], sample_size)

    def load_sports_understanding(self, cache_dir="data/bigbench/sports_understanding.csv", sample_size=None):
        """Read the sports-understanding CSV; sample ``input``/``target`` if sized."""
        return self._load_csv(cache_dir, ["input", "target"], sample_size)

    def load_disfl_qa(self, cache_dir="data/bigbench/disfl_qa.csv", sample_size=None):
        """Read the disfl_qa CSV; sample ``input``/``target`` if sized."""
        return self._load_csv(cache_dir, ["input", "target"], sample_size)

    def load_strategy_qa(self, cache_dir="data/bigbench/strategy_qa.csv", sample_size=None):
        """Read the strategy_qa CSV; sample ``input``/``target`` if sized."""
        return self._load_csv(cache_dir, ["input", "target"], sample_size)

    def load_sotu_qa(self, cache_dir="data/SOTU/SOTU_QA.csv", sample_size=None):
        """Read the SOTU QA CSV; sample ``question``/``answer`` if sized."""
        return self._load_csv(cache_dir, ["question", "answer"], sample_size)