Spaces:
Running
Running
import json | |
import argparse | |
import json | |
from utils import * | |
from dataset import * | |
from metrics import * | |
from compute_correlations import compute_flickr | |
from compute_pascal50s import compute_pascal50S | |
from compute_foil import compute_foil | |
def collect_coef(memory, dataset_name, method, coef_tensor):
    """Store one method's coefficients, rounded to 4 decimals, under a dataset.

    Args:
        memory: Results dict, mutated in place. ``memory[dataset_name]`` is
            created if missing and ``memory[dataset_name][method]`` is set.
        dataset_name: Key of the dataset the coefficients belong to.
        method: Name of the scoring method (e.g. ``"Polos"``).
        coef_tensor: Mapping of coefficient name -> scalar-like value. Python
            numbers, NumPy scalars / size-1 arrays, and size-1 tensors exposing
            ``.item()`` or ``.numpy()`` are all accepted.
    """
    def _as_float(value):
        # Plain numbers convert directly. Previously, ints (and bools) fell
        # through to ``.numpy()`` and raised AttributeError.
        if isinstance(value, (int, float)):
            return float(value)
        # ``.item()`` covers NumPy scalars/arrays (np.float32 is not a float
        # subclass and has no ``.numpy()``) and torch tensors on any device or
        # with requires_grad, where ``Tensor.numpy()`` would raise.
        if hasattr(value, "item"):
            return float(value.item())
        # Fallback kept from the original for objects exposing only ``.numpy()``.
        if hasattr(value, "numpy"):
            return float(value.numpy())
        return float(value)

    coef = {k: round(_as_float(v), 4) for k, v in coef_tensor.items()}
    memory.setdefault(dataset_name, {})[method] = coef
    gprint(f"[{dataset_name}]", method, coef)
def compute_coef(args, memory, tops):
    """Score the CSV at ``data_en/polaris/polaris_test.csv`` with enabled metrics.

    Only the Polos metric is wired up here: when ``args.polos`` is set, its
    coefficients (computed with ``kendall_type='c'``) are recorded in
    ``memory`` under the ``"test"`` key via ``collect_coef``.

    Returns:
        The ``(memory, tops)`` pair; ``tops`` is passed through unchanged.
    """
    split = "test"
    csv_path = f"data_en/polaris/polaris_{split}.csv"
    yprint(f"Processing {split} ... (path: {csv_path})")
    dataset = get_dataset(csv_path)

    if not args.polos:
        return memory, tops

    scores = compute_polos_coef(args, dataset, split, kendall_type='c')
    collect_coef(memory, split, "Polos", scores)
    return memory, tops
def main(args):
    """Run every benchmark enabled on the command line, then save and report.

    Benchmarks execute in a fixed order (flickr, coef, pascal, foil). Each one
    receives and returns the shared ``(memory, tops)`` pair. Afterwards,
    ``memory`` is written to ``zeroshot_test_results.json`` and pretty-printed,
    and each entry of ``tops`` is printed.
    """
    # (args attribute acting as the on/off flag, runner), in execution order.
    runners = (
        ("flickr", lambda mem, best: compute_flickr(args, args.model, mem, best)),
        ("coef", lambda mem, best: compute_coef(args, mem, best)),
        ("pascal", lambda mem, best: compute_pascal50S(args, mem, best)),
        ("foil", lambda mem, best: compute_foil(args, mem, best)),
    )

    memory, tops = {}, {}
    for flag, run in runners:
        if getattr(args, flag):
            memory, tops = run(memory, tops)

    with open("zeroshot_test_results.json", "w") as out:
        json.dump(memory, out, indent=4)

    yprint("[RESULTS]")
    gprint(json.dumps(memory, indent=4))

    rprint("[TOP]")
    for name, entry in tops.items():
        rprint(f"> {name}")
        if not isinstance(entry, dict):
            # acc: a single (method, accuracy) pair.
            method, acc = entry
            rprint(f"{method} ({acc})")
            continue
        # coef: each kind maps to an indexable whose first two items are shown.
        for kind, coef in entry.items():
            rprint(f"{kind}: {coef[0]} ({coef[1]})")
if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main().
    cli = argparse.ArgumentParser()

    # models
    cli.add_argument("--model", default=None)
    cli.add_argument("--hparams", default=None)
    cli.add_argument("--polos", action="store_true")

    # benchmarks (registered in the same order as before, so --help is unchanged)
    for flag in ("--coef", "--flickr", "--pascal", "--foil"):
        cli.add_argument(flag, action="store_true")

    args = cli.parse_args()
    main(args)