File size: 2,111 Bytes
50efa30
 
 
 
 
 
 
e0c3c75
 
50efa30
e0c3c75
 
50efa30
e0c3c75
 
 
50efa30
e0c3c75
 
 
 
 
 
 
 
50efa30
 
 
e0c3c75
 
50efa30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c3c75
50efa30
e0c3c75
50efa30
 
e0c3c75
50efa30
 
 
 
 
 
 
 
 
 
 
 
 
e0c3c75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import wandb

from PIL import Image
from PIL.Image import Transpose, Resampling


# api = wandb.Api()
# run = api.run("acozma/cs581/5ttfkav8")

# print(run.summary)
# print("Downloading images...")

# for file in run.files():
#     if file.name.endswith(".png"):
#         file.download(exist_ok=True)

# print("Finished downloading images")


folder_path = "./imgs/"
policy_file_prefix = "Pi"
q_file_prefix = "Q"
out_name = "qtable_policy"
sort_lambda = lambda x: int(x.split("_")[1].split(".")[0])  # key used to sort image filenames


def process_images(image_fnames, upscale=20):
    print(image_fnames)
    image_fnames.sort(key=sort_lambda)
    frames = [Image.open(image) for image in image_fnames]
    frames = [frame.transpose(Transpose.ROTATE_90) for frame in frames]
    frames = [
        frame.resize(
            (frame.size[0] * upscale, frame.size[1] * upscale),
            resample=Resampling.NEAREST,
        )
        for frame in frames
    ]
    return frames


def images_to_gif(frames, fname, duration=500):
    print(f"Creating gif: {fname}")

    frame_one = frames[0]
    frame_one.save(
        f"{fname}.gif",
        format="GIF",
        append_images=frames,
        save_all=True,
        duration=duration,
        loop=0,
    )


all_fnames = [os.path.join(folder_path, f) for f in os.listdir(folder_path)]
print(all_fnames)

fnames_policy = [f for f in all_fnames if os.path.basename(f).startswith(policy_file_prefix)))]
policy_frames = process_images(fnames_policy)

fnames_qtable = [f for f in all_fnames if os.path.basename(f).startswith(q_file_prefix)]
qtable_frames = process_images(fnames_qtable)

spacing_factor = 1 / 2

final_frames = []
for i, (qtable, policy) in enumerate(zip(qtable_frames, policy_frames)):
    width, height = qtable.size
    final_height = int(height * 2 + height * spacing_factor)
    new_frame = Image.new("RGB", (width, final_height), color="white")
    new_frame.paste(qtable, (0, 0))
    new_frame.paste(policy, (0, height + int(height * spacing_factor)))
    final_frames.append(new_frame)

images_to_gif(final_frames, out_name)