File size: 2,038 Bytes
d323598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import annotations

import os

import torch
from safetensors.torch import save_file

ckpt = "path_to/pytorch_model.bin"

vista_bin = torch.load(ckpt, map_location="cpu")  # only contains model weights

for k in list(vista_bin.keys()):  # merge LoRA weights (if exist) for inference
    if "adapter_down" in k:
        if "q_adapter_down" in k:
            up_k = k.replace("q_adapter_down", "q_adapter_up")
            pretrain_k = k.replace("q_adapter_down", "to_q")
        elif "k_adapter_down" in k:
            up_k = k.replace("k_adapter_down", "k_adapter_up")
            pretrain_k = k.replace("k_adapter_down", "to_k")
        elif "v_adapter_down" in k:
            up_k = k.replace("v_adapter_down", "v_adapter_up")
            pretrain_k = k.replace("v_adapter_down", "to_v")
        else:
            up_k = k.replace("out_adapter_down", "out_adapter_up")
            if "model_ema" in k:
                pretrain_k = k.replace("out_adapter_down", "to_out0")
            else:
                pretrain_k = k.replace("out_adapter_down", "to_out.0")

        lora_weights = vista_bin[up_k] @ vista_bin[k]
        del vista_bin[k]
        del vista_bin[up_k]
        vista_bin[pretrain_k] = vista_bin[pretrain_k] + lora_weights

for k in list(vista_bin.keys()):  # remove the prefix
    if "_forward_module" in k and "decay" not in k and "num_updates" not in k:
        vista_bin[k.replace("_forward_module.", "")] = vista_bin[k]
    del vista_bin[k]

for k in list(vista_bin.keys()):  # combine EMA weights
    if "model_ema" in k:
        orig_k = None
        for kk in list(vista_bin.keys()):
            if "model_ema" not in kk and k[10:] == kk[6:].replace(".", ""):
                orig_k = kk
        assert orig_k is not None
        vista_bin[orig_k] = vista_bin[k]
        del vista_bin[k]
        print("Replace", orig_k, "with", k)

vista_st = dict()
for k in list(vista_bin.keys()):
    vista_st[k] = vista_bin[k]

os.makedirs("ckpts", exist_ok=True)
save_file(vista_st, "ckpts/vista.safetensors")