import os
from typing import List, Optional

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM, BitsAndBytesConfig

def kmp_preprocess(pattern):
    """Build the KMP prefix (failure) table for `pattern`."""
    pattern_len = len(pattern)
    prefix_suffix = [0] * pattern_len
    j = 0

    for i in range(1, pattern_len):
        while j > 0 and pattern[i] != pattern[j]:
            j = prefix_suffix[j - 1]

        if pattern[i] == pattern[j]:
            j += 1

        prefix_suffix[i] = j

    return prefix_suffix

def kmp_search(text, pattern):
    """Return the start index of every occurrence of `pattern` in `text` (KMP search)."""
    text_len = len(text)
    pattern_len = len(pattern)
    prefix_suffix = kmp_preprocess(pattern)
    matches = []

    j = 0
    for i in range(text_len):
        while j > 0 and text[i] != pattern[j]:
            j = prefix_suffix[j - 1]

        if text[i] == pattern[j]:
            j += 1

        if j == pattern_len:
            matches.append(i - j + 1)
            j = prefix_suffix[j - 1]

    return matches
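
# Note: the KMP helpers above are not referenced elsewhere in this module. A
# hypothetical sanity check of their behavior:
#   kmp_search("abcabcabd", "abc")  ->  [0, 3]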

class ModelWrapper:
  """Thin wrapper around a frozen model: attribute access is forwarded to the
  wrapped model, calls run under torch.no_grad(), and eval()/train() are no-ops."""
  def __init__(self, model):
    self.model = model

  def __getattr__(self, name):
    return getattr(self.model, name)

  @torch.no_grad()
  def __call__(self, pixel_values):
    return self.model(pixel_values)
  
  def eval(self):
    pass

  def train(self):
    pass

  
  def parameters(self):
    return self.model.parameters()
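
# Hypothetical usage sketch (not exercised in this file): ModelWrapper appears
# intended to hold a frozen feature extractor such as a vision encoder, e.g.:
#   encoder = ModelWrapper(AutoModel.from_pretrained("openai/clip-vit-base-patch32"))
#   features = encoder(pixel_values)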


class CrelloModelConfig(PretrainedConfig):
    def __init__(
        self,
        old_vocab_size: int = 32000,
        vocab_size: int = 32000,
        pad_token_id: int = 2,
        ignore_ids: Optional[List[int]] = None,  # avoid a shared mutable default
        
        freeze_lm: bool = True, # lm.eval()
        opt_version: str = 'facebook/opt-6.7b',
        
        task: str = 'captioning',
        
        use_lora: bool = False,
        lora_alpha: int = 32,
        lora_r: int = 8,
        lora_dropout: float = 0.05,
        lora_target_modules: str = r'.*\.(q_proj|v_proj)',
        
        hidden_size: int = -1,
        load_in_4bit: Optional[bool] = False,
        
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert old_vocab_size > 0, 'old_vocab_size must be positive'
        assert vocab_size > 0, 'vocab_size must be positive'

        self.old_vocab_size = old_vocab_size
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.freeze_lm = freeze_lm
        self.opt_version = opt_version
        self.task = task
        self.use_lora = use_lora
        self.lora_alpha = lora_alpha
        self.lora_r = lora_r
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules
        self.hidden_size = hidden_size
        self.load_in_4bit = load_in_4bit
        self.ignore_ids = ignore_ids if ignore_ids is not None else []


class CrelloModel(PreTrainedModel):
  config_class = CrelloModelConfig
  supports_gradient_checkpointing = True

  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    self.lm.gradient_checkpointing_enable()

  def __init__(self, config: CrelloModelConfig):  # explicitly annotate the expected config type
    super().__init__(config)
    use_auth_token = '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'

    self.pad_token_id = config.pad_token_id

    self.args = config

    opt_version = config.opt_version

    print(f"Using {opt_version} for the language model.")

    if 'facebook/opt' in opt_version:
      self.lm = OPTForCausalLM.from_pretrained(opt_version)
      word_embed_proj_dim = self.lm.config.word_embed_proj_dim
    else:
      if config.load_in_4bit:
        print("\n would load_in_4bit")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=config.load_in_4bit
        )
        # Place the entire model on this process's local GPU.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        device_map = {"": local_rank}
        torch_dtype = torch.bfloat16
      else:
        print("\n wouldn't load_in_4bit")
        quantization_config = None
        device_map = None
        torch_dtype = None

      self.lm = AutoModelForCausalLM.from_pretrained(
        "WYBar/LLM_For_Layout_Planning",
        subfolder="Meta-Llama-3-8B",
        # use_auth_token=use_auth_token,
        # quantization_config=quantization_config,
        # device_map=device_map,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
      )
      # self.lm = AutoModelForCausalLM.from_pretrained(
      #   opt_version,
      #   use_auth_token=use_auth_token,
      #   quantization_config=quantization_config,
      #   device_map=device_map,
      #   trust_remote_code=True,
      #   # attn_implementation="flash_attention_2",
      #   # flash_attn=True,
      #   # flash_rotary=True,
      #   # fused_dense=True,
      #   torch_dtype=torch.bfloat16,
      # )
      word_embed_proj_dim = self.lm.config.hidden_size
    self.config.hidden_size = self.lm.config.hidden_size
    self.opt_version = opt_version

    if self.args.freeze_lm:
      self.lm.eval()
      print("Freezing the LM.")
      for param in self.lm.parameters():
        param.requires_grad = False
    else:
      print("\n no freeze lm, so to train lm")
      self.lm.train()
      self.lm.config.gradient_checkpointing = True

    print('Resizing token embeddings to match the tokenizer vocab size:', config.vocab_size)
    self.lm.resize_token_embeddings(config.vocab_size)
    self.input_embeddings = self.lm.get_input_embeddings()
    print('Token embeddings resized to:', config.vocab_size)
    
  def train(self, mode=True):
    super().train(mode=mode)
    # Overwrite train() to ensure frozen models remain frozen.
    if self.args.freeze_lm:
      self.lm.eval()

  def forward(
    self,
    labels: torch.LongTensor,
  ):
    print("inside Crello")
    batch_size = labels.shape[0]
    full_labels = labels.detach().clone()

    input_embs = self.input_embeddings(labels)  # (N, T, D)
    input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()

    for ignore_id in self.config.ignore_ids:
      full_labels[full_labels == ignore_id] = -100

    pad_idx = []
    # For each sequence in the batch, record its effective length (the position
    # of the first pad token, or the full length if there is no padding) in pad_idx.
    # -100 is the ignore index for cross entropy loss. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    for label in full_labels:
      for k, token in enumerate(label):
        # Mask out pad tokens if they exist.
        if token in [self.pad_token_id]:
          label[k:] = -100  # mask out the pad token and everything after it
          pad_idx.append(k)
          break
        if k == len(label) - 1:  # No padding found.
          pad_idx.append(k + 1)
    assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
    
    print("inside Crello, lm1")
    output = self.lm( inputs_embeds=input_embs,
                      # input_ids=labels,
                      labels=full_labels,
                      output_hidden_states=True)
    print("inside Crello, lm2")
    
    return output, full_labels, input_embs_norm

if __name__ == "__main__":
  config = CrelloModelConfig(
    vocab_size=50265,
    image_reg_token=50264,
    image_gt_token=50263,
  )
  print("config: ",config)
  model1 = CrelloModel(config)
  print("\nmodel1: ",model1)
  model1.save_pretrained('test')
  model2 = CrelloModel.from_pretrained('test')
  print("\nmodel2: ",model2)
  # compare model1 and model2

  state_dict1 = model1.state_dict()
  state_dict2 = model2.state_dict()
  assert set(state_dict1.keys()) == set(state_dict2.keys())
  for k in state_dict1.keys():
    assert torch.equal(state_dict1[k], state_dict2[k])
  print('all parameters are equal')
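
  # Optional, hypothetical sanity check: run a tiny forward pass on random token
  # ids (the LM is already loaded above, so this adds no new requirements). The
  # shape (1, 8) and the printed values are illustrative only.
  dummy_labels = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long)
  output, full_labels, input_embs_norm = model2(labels=dummy_labels)
  print('loss:', output.loss.item(), 'input_embs_norm:', input_embs_norm.item())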