File size: 1,871 Bytes
ad552d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import csv
import io
import os
import shutil
import urllib.request


def get_score(
    model_name=None,
    dataset_name=None,
    dataset_res=None,
    dataset_split=None,
    task_name=None,
    url="https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv",
):
    """Download the leaderboard CSV and return the rows matching the given filters.

    A filter that is not ``None`` must equal the stripped cell in the column of
    the same name. Filter values are compared via ``str()``, so e.g.
    ``dataset_res=256`` matches a cell containing ``"256"``.

    Args:
        model_name: Optional filter on the ``model_name`` column.
        dataset_name: Optional filter on the ``dataset_name`` column.
        dataset_res: Optional filter on the ``dataset_res`` column.
        dataset_split: Optional filter on the ``dataset_split`` column.
        task_name: Optional filter on the ``task_name`` column.
        url: Location of the CSV; anything ``urllib.request.urlopen`` accepts
            (including ``file://`` URLs). Defaults to the original server URL.

    Returns:
        A list with one dict per matching row, mapping each stripped header
        name to the stripped cell value. Cells missing from a short row are
        returned as ``""``. An empty CSV yields ``[]``.

    Raises:
        KeyError: if a filter is given for a column the CSV header lacks.
        urllib.error.URLError: if the CSV cannot be fetched.
    """
    filters = {
        "model_name": model_name,
        "dataset_name": dataset_name,
        "dataset_res": dataset_res,
        "dataset_split": dataset_split,
        "task_name": task_name,
    }
    wanted = {col: str(val) for col, val in filters.items() if val is not None}

    # Parse straight from memory instead of a fixed /tmp path: no races between
    # concurrent calls, no symlink tricks, and nothing to leak if parsing fails.
    with urllib.request.urlopen(url) as response:
        payload = response.read()
    # utf-8-sig drops a leading BOM so the first header name is not polluted.
    text = payload.decode("utf-8-sig")
    # newline="" lets the csv module handle line endings itself (per csv docs).
    reader = csv.reader(io.StringIO(text, newline=""))

    header = next(reader, None)
    if header is None:  # empty file: no header, hence no rows
        return []

    # Map stripped column name -> index; a later duplicate name overrides an earlier one.
    field2idx = {}
    for idx, name in enumerate(header):
        field2idx[name.strip()] = idx

    for col in wanted:
        if col not in field2idx:
            raise KeyError(f"leaderboard CSV has no {col!r} column")

    def cell(row, idx):
        # Rows shorter than the header are padded with "" instead of raising.
        return row[idx].strip() if idx < len(row) else ""

    matches = []
    for row in reader:
        if not row:  # skip blank lines
            continue
        if any(cell(row, field2idx[col]) != val for col, val in wanted.items()):
            continue
        matches.append({name: cell(row, idx) for name, idx in field2idx.items()})
    return matches