File size: 3,696 Bytes
3e88ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import json
import os

import torch

from diffusers import UNet1DModel


# Make sure every directory the conversion functions below write into exists
# before any of them runs. This happens at import time, as a module-level
# side effect.
for _output_dir in (
    "hub/hopper-medium-v2/unet/hor32",
    "hub/hopper-medium-v2/unet/hor128",
    "hub/hopper-medium-v2/value_function",
):
    os.makedirs(_output_dir, exist_ok=True)


def unet(hor, checkpoint_path=None):
    """Convert an original temporal UNet checkpoint into a diffusers ``UNet1DModel``.

    Loads the pickled original model, builds a ``UNet1DModel`` with the block
    layout for the requested horizon, renames the original parameters to the
    new model's parameter names purely by position, and writes the converted
    weights plus ``config.json`` to ``hub/hopper-medium-v2/unet/hor{hor}/``.

    Args:
        hor: Horizon value selecting the block layout and the default
            checkpoint file name. Only ``32`` and ``128`` are supported.
        checkpoint_path: Optional path to the original ``.torch`` checkpoint.
            When omitted, the original hard-coded location is used.

    Raises:
        ValueError: If ``hor`` is neither ``32`` nor ``128``.
        RuntimeError: If the original and the new model do not have the same
            number of state-dict entries, so a positional rename is impossible.
    """
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down when the layout variables were never assigned.
        raise ValueError(f"unsupported horizon {hor!r}; expected 32 or 128")

    if checkpoint_path is None:
        checkpoint_path = f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch"

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a trusted source.
    model = torch.load(checkpoint_path)
    state_dict = model.state_dict()
    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_unet = UNet1DModel(**config)
    hf_keys = list(hf_unet.state_dict().keys())
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_keys)}")

    # The rename is positional, so both sides must have the same number of
    # entries; zip() would otherwise silently drop the surplus.
    if len(state_dict) != len(hf_keys):
        raise RuntimeError(
            f"cannot map parameters by position: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel has {len(hf_keys)}"
        )

    # Build a fresh dict rather than renaming in place: an in-place
    # pop/assign can overwrite an entry when a new key equals an old key that
    # has not been processed yet.
    converted = {hf_key: tensor for hf_key, tensor in zip(hf_keys, state_dict.values())}
    hf_unet.load_state_dict(converted)

    torch.save(hf_unet.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)


def value_function(checkpoint_path=None):
    """Convert the original value-function checkpoint into a diffusers ``UNet1DModel``.

    Loads the original checkpoint (used directly as a state dict), builds a
    ``UNet1DModel`` configured with the ``ValueFunction`` output block,
    renames the original parameters to the new model's parameter names purely
    by position, and writes the converted weights plus ``config.json`` to
    ``hub/hopper-medium-v2/value_function/``.

    Args:
        checkpoint_path: Optional path to the original ``.torch`` checkpoint.
            When omitted, the original hard-coded location is used.

    Raises:
        RuntimeError: If the checkpoint and the new model do not have the same
            number of state-dict entries, so a positional rename is impossible.
    """
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    if checkpoint_path is None:
        checkpoint_path = "/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch"

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_keys = list(hf_value_function.state_dict().keys())
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_keys)}")

    # The rename is positional, so both sides must have the same number of
    # entries; zip() would otherwise silently drop the surplus.
    if len(state_dict) != len(hf_keys):
        raise RuntimeError(
            f"cannot map parameters by position: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel has {len(hf_keys)}"
        )

    # Build a fresh dict rather than renaming in place: an in-place
    # pop/assign can overwrite an entry when a new key equals an old key that
    # has not been processed yet.
    converted = {hf_key: tensor for hf_key, tensor in zip(hf_keys, state_dict.values())}
    hf_value_function.load_state_dict(converted)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)


if __name__ == "__main__":
    # Script entry point: convert the horizon-32 UNet first, then the value
    # function.
    # NOTE: the horizon-128 UNet conversion (``unet(128)``) is not run here.
    unet(32)
    value_function()