Data-Contamination-Database / postprocessing.py
OSainz's picture
Add PR number + Postprocessing
582a8ca
raw
history blame
1.47 kB
def load_file(filename):
    """Read a semicolon-separated report file.

    The first line is split on ``;`` to form the header. Every following
    line is stripped of surrounding whitespace and split on ``;``; lines
    that are empty after stripping are skipped.

    Args:
        filename: Path of the file to read.

    Returns:
        A ``(header, rows)`` tuple where ``header`` is a list of strings and
        ``rows`` is a list of lists of strings.
    """
    # NOTE(review): opened with the platform default encoding; if this is
    # changed, change the matching write in main() at the same time.
    with open(filename, 'r') as handle:
        header = handle.readline().strip().split(";")
        rows = []
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue  # blank separator lines carry no data
            rows.append(stripped.split(";"))
    return header, rows
def remove_duplicates(data):
    """Drop rows whose identifying columns repeat an earlier row.

    Two rows count as duplicates when their first four columns and their
    last column are all equal. The first occurrence is kept, and the
    relative order of the surviving rows is preserved.

    Args:
        data: Iterable of rows (indexable sequences).

    Returns:
        A new list containing the retained rows (the same row objects).
    """
    seen = set()
    unique_rows = []
    for row in data:
        signature = (row[0], row[1], row[2], row[3], row[-1])
        if signature not in seen:
            seen.add(signature)
            unique_rows.append(row)
    return unique_rows
def fix_arxiv_links(data):
    """Point arXiv PDF links at the paper's abstract page instead.

    The second-to-last column of every row is treated as the link column.
    Any ``arxiv.org/pdf`` in it is rewritten to ``arxiv.org/abs``. A trailing
    ``.pdf`` on such a rewritten link is dropped as well, so
    ``https://arxiv.org/pdf/2310.12345.pdf`` becomes
    ``https://arxiv.org/abs/2310.12345`` (the bare identifier form of the
    abstract URL) rather than ``.../abs/2310.12345.pdf``.

    Links that do not contain ``arxiv.org/pdf`` are left untouched. The
    input rows are not modified.

    Args:
        data: Iterable of rows with at least two columns.

    Returns:
        A new list of new row lists.
    """
    pdf_suffix = ".pdf"
    fixed = []
    for item in data:
        link = item[-2]
        if "arxiv.org/pdf" in link:
            link = link.replace("arxiv.org/pdf", "arxiv.org/abs")
            # Only strip the suffix on links we actually rewrote; non-arXiv
            # PDF links must keep their extension.
            if link.endswith(pdf_suffix):
                link = link[:-len(pdf_suffix)]
        fixed.append([*item[:-2], link, item[-1]])
    return fixed
def fix_openreview_links(data):
    """Point OpenReview PDF links at the forum page instead.

    In the second-to-last column of every row, each occurrence of
    ``openreview.net/pdf`` is replaced with ``openreview.net/forum``
    (e.g. ``openreview.net/pdf?id=X`` -> ``openreview.net/forum?id=X``).
    Other columns are copied unchanged and the input rows are not modified.

    Args:
        data: Iterable of rows with at least two columns.

    Returns:
        A new list of new row lists.
    """
    def _to_forum(row):
        link = row[-2].replace("openreview.net/pdf", "openreview.net/forum")
        return [*row[:-2], link, row[-1]]

    return list(map(_to_forum, data))
def sort_data(data):
    """Return the rows ordered by columns 0, 1, 2, 3 and then the last one.

    The sort is stable, so rows that tie on those columns keep their input
    order. The input is not modified.

    Args:
        data: Iterable of rows (indexable sequences).

    Returns:
        A new sorted list of the same row objects.
    """
    def _order(row):
        return tuple(row[i] for i in (0, 1, 2, 3, -1))

    ordered = list(data)
    ordered.sort(key=_order)
    return ordered
def main(filename="contamination_report.csv"):
    """Clean the semicolon-separated contamination report in place.

    The pipeline runs these steps in order:

    1. Read ``filename``.
    2. Sort the rows.
    3. Drop duplicate rows.
    4. Rewrite arXiv and OpenReview PDF links to their landing pages.
    5. Print the resulting row count.
    6. Write the file back.

    In the written file, a blank line precedes every run of rows that share
    the same first two columns, including the first run directly after the
    header. The reader skips blank lines, so the output can be re-processed.

    Args:
        filename: Path of the report to read and overwrite. Defaults to
            ``"contamination_report.csv"`` in the working directory.
    """
    header, data = load_file(filename)
    data = sort_data(data)
    data = remove_duplicates(data)
    data = fix_arxiv_links(data)
    data = fix_openreview_links(data)
    print("Total datapoints:", len(data))

    # Render the whole output before opening the file for writing. Input
    # and output are the same file, so an error raised mid-render after
    # the file was truncated would lose the report.
    out_lines = [";".join(header)]
    past_key = None
    for line in data:
        key = (line[0], line[1])
        if key != past_key:
            out_lines.append("")  # blank separator line between groups
            past_key = key
        out_lines.append(";".join(line))
    text = "\n".join(out_lines) + "\n"

    # NOTE(review): written with the platform default encoding, matching the
    # read; switch both to encoding="utf-8" together if that is intended.
    with open(filename, 'w') as f:
        f.write(text)


if __name__ == "__main__":
    main()