File size: 3,396 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import torch

from colbert.utils.utils import torch_load_dnn

from transformers import AutoTokenizer
from colbert.modeling.hf_colbert import HF_ColBERT
from colbert.infra.config import ColBERTConfig


class BaseColBERT(torch.nn.Module):
    """
    Shallow module that wraps the ColBERT parameters, custom configuration, and underlying tokenizer.
    This class provides direct instantiation and saving of the model/colbert_config/tokenizer package.

    Like HF, evaluation mode is the default.
    """

    def __init__(self, name, colbert_config=None):
        super().__init__()

        self.name = name
        self.colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name), colbert_config)
        self.model = HF_ColBERT.from_pretrained(name, colbert_config=self.colbert_config)
        self.raw_tokenizer = AutoTokenizer.from_pretrained(self.model.base)

        self.eval()

    @property
    def device(self):
        return self.model.device

    @property
    def bert(self):
        return self.model.bert

    @property
    def linear(self):
        return self.model.linear
    
    @property
    def score_scaler(self):
        return self.model.score_scaler

    def save(self, path):
        assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format."

        self.model.save_pretrained(path)
        self.raw_tokenizer.save_pretrained(path)

        self.colbert_config.save_for_checkpoint(path)


if __name__ == '__main__':
    import random
    import numpy as np

    from colbert.infra.run import Run
    from colbert.infra.config import RunConfig

    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)

    with Run().context(RunConfig(gpus=2)):
        m = BaseColBERT('bert-base-uncased', colbert_config=ColBERTConfig(Run().config, doc_maxlen=300, similarity='l2'))
        m.colbert_config.help()
        print(m.linear.weight)
        m.save('/future/u/okhattab/tmp/2021/08/model.deleteme2/')

    m2 = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme2/')
    m2.colbert_config.help()
    print(m2.linear.weight)

    exit()

    m = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme/')
    print('BaseColBERT', m.linear.weight)
    print('BaseColBERT', m.colbert_config)

    exit()

    # m = HF_ColBERT.from_pretrained('nreimers/MiniLMv2-L6-H768-distilled-from-BERT-Large')
    m = HF_ColBERT.from_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/')
    print('HF_ColBERT', m.linear.weight)

    m.save_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/')

    # old = OldColBERT.from_pretrained('bert-base-uncased')
    # print(old.bert.encoder.layer[10].attention.self.value.weight)

    # random.seed(12345)
    # np.random.seed(12345)
    # torch.manual_seed(12345)

    dnn = torch_load_dnn(
        "/future/u/okhattab/root/TACL21/experiments/Feb26.NQ/train.py/ColBERT.C3/checkpoints/colbert-60000.dnn")
    # base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

    # new = BaseColBERT.from_pretrained('bert-base-uncased', state_dict=dnn['model_state_dict'])

    # print(new.bert.encoder.layer[10].attention.self.value.weight)

    print(dnn['model_state_dict']['linear.weight'])
    # print(dnn['model_state_dict']['bert.encoder.layer.10.attention.self.value.weight'])

    # # base_model_prefix