patrickvonplaten commited on
Commit
56fb00e
·
1 Parent(s): e402ae6
Files changed (1) hide show
  1. update_almost_agi.py +80 -0
update_almost_agi.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+
4
+ from collections import Counter
5
+ from datasets import load_dataset
6
+ from huggingface_hub import HfApi, list_datasets
7
+
8
+
9
+ api = HfApi(token=os.environ.get("HF_TOKEN", None))
10
+ def restart_space():
11
+ api.restart_space(repo_id="OpenGenAI/parti-prompts-leaderboard")
12
+
13
+ parti_prompt_results = []
14
+ ORG = "diffusers-parti-prompts"
15
+ SUBMISSIONS = {
16
+ "kand2": None,
17
+ "sdxl": None,
18
+ "wuerst": None,
19
+ "karlo": None,
20
+ }
21
+ LINKS = {
22
+ "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
23
+ "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
24
+ "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
25
+ "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
26
+ }
27
+ MODEL_KEYS = "-".join(SUBMISSIONS.keys())
28
+ SUBMISSION_ORG = f"result-{MODEL_KEYS}"
29
+
30
+ submission_names = list(SUBMISSIONS.keys())
31
+
32
+ ORG = "diffusers-parti-prompts"
33
+ SUBMISSIONS = {
34
+ "kand2": load_dataset(os.path.join(ORG, "kandinsky-2-2"))["train"],
35
+ "sdxl": load_dataset(os.path.join(ORG, "sdxl-1.0-refiner"))["train"],
36
+ "wuerst": load_dataset(os.path.join(ORG, "wuerstchen"))["train"],
37
+ "karlo": load_dataset(os.path.join(ORG, "karlo-v1"))["train"],
38
+ }
39
+ ds = load_dataset("nateraw/parti-prompts")["train"]
40
+
41
+ parti_prompt_categories = ds["Category"]
42
+ parti_prompt_challenge = ds["Challenge"]
43
+
44
+ UPLOAD_ORG = "almost-agi-diff"
45
+
46
+ def load_non_solved():
47
+ all_datasets = list_datasets(author=SUBMISSION_ORG)
48
+ relevant_ids = [d.id for d in all_datasets]
49
+
50
+ all_non_solved_image_ids = []
51
+
52
+ for _id in relevant_ids[:5]:
53
+ try:
54
+ ds = load_dataset(_id)["train"]
55
+ except:
56
+ continue
57
+
58
+ for result, image_id in zip(ds["result"], ds["id"]):
59
+ if result == "":
60
+ all_non_solved_image_ids.append(image_id)
61
+
62
+ all_non_solved_image_ids_dict = Counter(all_non_solved_image_ids)
63
+ all_non_solved_image_ids = list(all_non_solved_image_ids_dict.keys())
64
+ all_non_solved_image_votes = list(all_non_solved_image_ids_dict.values())
65
+
66
+ return all_non_solved_image_ids, all_non_solved_image_votes
67
+
68
+ def main():
69
+ non_solved_ids, upvotes = load_non_solved()
70
+
71
+ for name, ds in SUBMISSIONS.items():
72
+ ds_to_push = ds.select(non_solved_ids)
73
+
74
+ votes_column = upvotes
75
+
76
+ ds_to_push.add_column("upvotes", votes_column)
77
+ sorted_ds = ds_to_push.sort("upvotes", reverse=True)
78
+
79
+ import ipdb; ipdb.set_trace()
80
+ sorted_ds.push_to_hub(f"{UPLOAD_ORG}/{name}")