English
naveensp commited on
Commit
fbfa1a7
·
verified ·
1 Parent(s): 8e511a7

Delete folder model with huggingface_hub

Browse files
Files changed (47) hide show
  1. model/__init__.py +0 -6
  2. model/__pycache__/__init__.cpython-310.pyc +0 -0
  3. model/__pycache__/__init__.cpython-311.pyc +0 -0
  4. model/__pycache__/__init__.cpython-312.pyc +0 -0
  5. model/__pycache__/__init__.pypy39.pyc +0 -0
  6. model/__pycache__/builder.cpython-311.pyc +0 -0
  7. model/__pycache__/builder.cpython-312.pyc +0 -0
  8. model/__pycache__/llava_arch.cpython-310.pyc +0 -0
  9. model/__pycache__/llava_arch.cpython-311.pyc +0 -0
  10. model/__pycache__/llava_arch.cpython-312.pyc +0 -0
  11. model/apply_delta.py +0 -48
  12. model/builder.py +0 -191
  13. model/consolidate.py +0 -29
  14. model/download_pretrain_dataset.py +0 -5
  15. model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
  16. model/language_model/__pycache__/llava_llama.cpython-311.pyc +0 -0
  17. model/language_model/__pycache__/llava_llama.cpython-312.pyc +0 -0
  18. model/language_model/__pycache__/llava_llama.pypy39.pyc +0 -0
  19. model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
  20. model/language_model/__pycache__/llava_mistral.cpython-311.pyc +0 -0
  21. model/language_model/__pycache__/llava_mistral.cpython-312.pyc +0 -0
  22. model/language_model/__pycache__/llava_mpt.cpython-310.pyc +0 -0
  23. model/language_model/__pycache__/llava_mpt.cpython-311.pyc +0 -0
  24. model/language_model/__pycache__/llava_mpt.cpython-312.pyc +0 -0
  25. model/language_model/__pycache__/llava_olmo1p58b.cpython-310.pyc +0 -0
  26. model/language_model/__pycache__/llava_olmo1p58b.cpython-311.pyc +0 -0
  27. model/language_model/__pycache__/llava_olmo1p58b.cpython-312.pyc +0 -0
  28. model/language_model/llava_llama.py +0 -158
  29. model/language_model/llava_mistral.py +0 -158
  30. model/language_model/llava_mpt.py +0 -97
  31. model/language_model/llava_olmo.py +0 -115
  32. model/language_model/llava_olmo1p58b.py +0 -164
  33. model/llava_arch.py +0 -369
  34. model/make_delta.py +0 -52
  35. model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  36. model/multimodal_encoder/__pycache__/builder.cpython-311.pyc +0 -0
  37. model/multimodal_encoder/__pycache__/builder.cpython-312.pyc +0 -0
  38. model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  39. model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc +0 -0
  40. model/multimodal_encoder/__pycache__/clip_encoder.cpython-312.pyc +0 -0
  41. model/multimodal_encoder/builder.py +0 -15
  42. model/multimodal_encoder/clip_encoder.py +0 -147
  43. model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  44. model/multimodal_projector/__pycache__/builder.cpython-311.pyc +0 -0
  45. model/multimodal_projector/__pycache__/builder.cpython-312.pyc +0 -0
  46. model/multimodal_projector/builder.py +0 -51
  47. model/utils.py +0 -20
model/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- try:
2
- from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3
- from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4
- from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5
- except:
6
- pass
 
 
 
 
 
 
 
model/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (457 Bytes)
 
model/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (562 Bytes)
 
model/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (486 Bytes)
 
model/__pycache__/__init__.pypy39.pyc DELETED
Binary file (455 Bytes)
 
model/__pycache__/builder.cpython-311.pyc DELETED
Binary file (12.1 kB)
 
model/__pycache__/builder.cpython-312.pyc DELETED
Binary file (10.8 kB)
 
model/__pycache__/llava_arch.cpython-310.pyc DELETED
Binary file (10.7 kB)
 
model/__pycache__/llava_arch.cpython-311.pyc DELETED
Binary file (23.3 kB)
 
model/__pycache__/llava_arch.cpython-312.pyc DELETED
Binary file (20 kB)
 
model/apply_delta.py DELETED
@@ -1,48 +0,0 @@
1
- """
2
- Usage:
3
- python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
- """
5
- import argparse
6
-
7
- import torch
8
- from tqdm import tqdm
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
- from llava import LlavaLlamaForCausalLM
11
-
12
-
13
- def apply_delta(base_model_path, target_model_path, delta_path):
14
- print("Loading base model")
15
- base = AutoModelForCausalLM.from_pretrained(
16
- base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
-
18
- print("Loading delta")
19
- delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
- delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
-
22
- print("Applying delta")
23
- for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
- if name not in base.state_dict():
25
- assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
- continue
27
- if param.data.shape == base.state_dict()[name].shape:
28
- param.data += base.state_dict()[name]
29
- else:
30
- assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
- f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
- bparam = base.state_dict()[name]
33
- param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
-
35
- print("Saving target model")
36
- delta.save_pretrained(target_model_path)
37
- delta_tokenizer.save_pretrained(target_model_path)
38
-
39
-
40
- if __name__ == "__main__":
41
- parser = argparse.ArgumentParser()
42
- parser.add_argument("--base-model-path", type=str, required=True)
43
- parser.add_argument("--target-model-path", type=str, required=True)
44
- parser.add_argument("--delta-path", type=str, required=True)
45
-
46
- args = parser.parse_args()
47
-
48
- apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/builder.py DELETED
@@ -1,191 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import os
17
- import warnings
18
- import shutil
19
-
20
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
- import torch
22
- from llava.model import *
23
- from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
- import json
25
- import llava.model.language_model.llava_olmo1p58b as llava_olmo
26
-
27
-
28
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
29
- kwargs = {"device_map": device_map, **kwargs}
30
-
31
- if device != "cuda":
32
- kwargs['device_map'] = {"": device}
33
-
34
- if load_8bit:
35
- kwargs['load_in_8bit'] = True
36
- elif load_4bit:
37
- kwargs['load_in_4bit'] = True
38
- kwargs['quantization_config'] = BitsAndBytesConfig(
39
- load_in_4bit=True,
40
- bnb_4bit_compute_dtype=torch.float16,
41
- bnb_4bit_use_double_quant=True,
42
- bnb_4bit_quant_type='nf4'
43
- )
44
- else:
45
- kwargs['torch_dtype'] = torch.float16
46
-
47
- if use_flash_attn:
48
- kwargs['attn_implementation'] = 'flash_attention_2'
49
-
50
- if 'llava' in model_name.lower() and 'olmo' not in model_name.lower():
51
- # Load LLaVA model
52
- if 'lora' in model_name.lower() and model_base is None:
53
- warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
54
- if 'lora' in model_name.lower() and model_base is not None:
55
- from llava.model.language_model.llava_llama import LlavaConfig
56
- lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
57
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
58
- print('Loading LLaVA from base model...')
59
- model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
60
- token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
61
- if model.lm_head.weight.shape[0] != token_num:
62
- model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
63
- model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
64
-
65
- print('Loading additional LLaVA weights...')
66
- if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
67
- non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
68
- else:
69
- # this is probably from HF Hub
70
- from huggingface_hub import hf_hub_download
71
- def load_from_hf(repo_id, filename, subfolder=None):
72
- cache_file = hf_hub_download(
73
- repo_id=repo_id,
74
- filename=filename,
75
- subfolder=subfolder)
76
- return torch.load(cache_file, map_location='cpu')
77
- non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
78
- non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
79
- if any(k.startswith('model.model.') for k in non_lora_trainables):
80
- non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
81
- model.load_state_dict(non_lora_trainables, strict=False)
82
-
83
- from peft import PeftModel
84
- print('Loading LoRA weights...')
85
- model = PeftModel.from_pretrained(model, model_path)
86
- print('Merging LoRA weights...')
87
- model = model.merge_and_unload()
88
- print('Model is loaded...')
89
- elif model_base is not None:
90
- # this may be mm projector only
91
- print('Loading LLaVA from base model...')
92
- if 'mpt' in model_name.lower():
93
- if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
94
- shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
95
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
96
- cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
97
- model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
98
- else:
99
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
100
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
101
- model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
102
-
103
- mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
104
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
105
- model.load_state_dict(mm_projector_weights, strict=False)
106
- else:
107
- if 'mpt' in model_name.lower():
108
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
109
- model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
110
- elif 'mistral' in model_name.lower():
111
- tokenizer = AutoTokenizer.from_pretrained(model_path)
112
- model = LlavaMistralForCausalLM.from_pretrained(
113
- model_path,
114
- low_cpu_mem_usage=True,
115
- **kwargs
116
- )
117
- else:
118
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
119
- model = LlavaLlamaForCausalLM.from_pretrained(
120
- model_path,
121
- low_cpu_mem_usage=True,
122
- **kwargs
123
- )
124
- elif 'llava' in model_name.lower() and 'olmo' in model_name.lower():
125
- # Newly introduced case through Olmollavabitnet1B where we load the model params from pretrained path. output must be tokenizer and model
126
- print('Setting up LLaVaOLMOBitnet1B for eval........')
127
- with open('checkpoints/llava-LlavaOLMoBitnet1B-Run2-finetune/config.json') as json_file:
128
- data = json.load(json_file)
129
-
130
- config_class = llava_olmo.LlavaOLMoBitnet1BConfig(**data)
131
- model = llava_olmo.LlavaOLMoBitnet1BForCausalLM(config_class).to(device)
132
- model.model.vision_tower.load_model()
133
- weight_checkpoint = torch.load('checkpoints/llava-LlavaOLMoBitnet1B-Run3-finetune/pytorch_model.bin')
134
- model.load_state_dict(weight_checkpoint)
135
-
136
- tokenizer = AutoTokenizer.from_pretrained(
137
- "NousResearch/OLMo-Bitnet-1B",
138
- model_max_length=2048,
139
- padding_side="right",
140
- pad_token_id=1,
141
- use_fast=True,
142
- legacy=False,
143
- unk_token='<|padding|>',
144
- )
145
-
146
- else:
147
- # Load language model
148
- if model_base is not None:
149
- # PEFT model
150
- from peft import PeftModel
151
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
152
- model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
153
- print(f"Loading LoRA weights from {model_path}")
154
- model = PeftModel.from_pretrained(model, model_path)
155
- print(f"Merging weights")
156
- model = model.merge_and_unload()
157
- print('Convert to FP16...')
158
- model.to(torch.float16)
159
- else:
160
- use_fast = False
161
- if 'mpt' in model_name.lower():
162
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
163
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
164
- else:
165
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
166
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
167
-
168
- image_processor = None
169
-
170
- if 'llava' in model_name.lower():
171
- mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
172
- mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
173
- if mm_use_im_patch_token:
174
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
175
- if mm_use_im_start_end:
176
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
177
- model.resize_token_embeddings(len(tokenizer))
178
-
179
- vision_tower = model.get_vision_tower()
180
- if not vision_tower.is_loaded:
181
- vision_tower.load_model(device_map=device_map)
182
- if device_map != 'auto':
183
- vision_tower.to(device=device_map, dtype=torch.float16)
184
- image_processor = vision_tower.image_processor
185
-
186
- if hasattr(model.config, "max_sequence_length"):
187
- context_len = model.config.max_sequence_length
188
- else:
189
- context_len = 2048
190
-
191
- return tokenizer, model, image_processor, context_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/consolidate.py DELETED
@@ -1,29 +0,0 @@
1
- """
2
- Usage:
3
- python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
- """
5
- import argparse
6
-
7
- import torch
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
- from llava.model import *
10
- from llava.model.utils import auto_upgrade
11
-
12
-
13
- def consolidate_ckpt(src_path, dst_path):
14
- print("Loading model")
15
- auto_upgrade(src_path)
16
- src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
- src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
- src_model.save_pretrained(dst_path)
19
- src_tokenizer.save_pretrained(dst_path)
20
-
21
-
22
- if __name__ == "__main__":
23
- parser = argparse.ArgumentParser()
24
- parser.add_argument("--src", type=str, required=True)
25
- parser.add_argument("--dst", type=str, required=True)
26
-
27
- args = parser.parse_args()
28
-
29
- consolidate_ckpt(args.src, args.dst)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/download_pretrain_dataset.py DELETED
@@ -1,5 +0,0 @@
1
- from datasets import load_dataset
2
-
3
- dataset = load_dataset("liuhaotian/LLaVA-Pretrain")
4
- dataset.save_to_disk("LLaVa_Pretrain_dataset")
5
-
 
 
 
 
 
 
model/language_model/__pycache__/llava_llama.cpython-310.pyc DELETED
Binary file (3.8 kB)
 
model/language_model/__pycache__/llava_llama.cpython-311.pyc DELETED
Binary file (6.31 kB)
 
model/language_model/__pycache__/llava_llama.cpython-312.pyc DELETED
Binary file (5.53 kB)
 
model/language_model/__pycache__/llava_llama.pypy39.pyc DELETED
Binary file (4.3 kB)
 
model/language_model/__pycache__/llava_mistral.cpython-310.pyc DELETED
Binary file (3.84 kB)
 
model/language_model/__pycache__/llava_mistral.cpython-311.pyc DELETED
Binary file (6.32 kB)
 
model/language_model/__pycache__/llava_mistral.cpython-312.pyc DELETED
Binary file (5.5 kB)
 
model/language_model/__pycache__/llava_mpt.cpython-310.pyc DELETED
Binary file (3.17 kB)
 
model/language_model/__pycache__/llava_mpt.cpython-311.pyc DELETED
Binary file (5.01 kB)
 
model/language_model/__pycache__/llava_mpt.cpython-312.pyc DELETED
Binary file (4.34 kB)
 
model/language_model/__pycache__/llava_olmo1p58b.cpython-310.pyc DELETED
Binary file (4 kB)
 
model/language_model/__pycache__/llava_olmo1p58b.cpython-311.pyc DELETED
Binary file (6.6 kB)
 
model/language_model/__pycache__/llava_olmo1p58b.cpython-312.pyc DELETED
Binary file (5.69 kB)
 
model/language_model/llava_llama.py DELETED
@@ -1,158 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from transformers import AutoConfig, AutoModelForCausalLM, \
22
- LlamaConfig, LlamaModel, LlamaForCausalLM
23
-
24
- from transformers.modeling_outputs import CausalLMOutputWithPast
25
- from transformers.generation.utils import GenerateOutput
26
-
27
- from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
-
29
-
30
- class LlavaConfig(LlamaConfig):
31
- model_type = "llava_llama"
32
-
33
-
34
- class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
- config_class = LlavaConfig
36
-
37
- def __init__(self, config: LlamaConfig):
38
- super(LlavaLlamaModel, self).__init__(config)
39
-
40
-
41
- class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
- config_class = LlavaConfig
43
-
44
- def __init__(self, config):
45
- super(LlamaForCausalLM, self).__init__(config)
46
- self.model = LlavaLlamaModel(config)
47
- self.pretraining_tp = config.pretraining_tp
48
- self.vocab_size = config.vocab_size
49
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
-
51
- # Initialize weights and apply final processing
52
- self.post_init()
53
-
54
- def get_model(self):
55
- return self.model
56
-
57
- def forward(
58
- self,
59
- input_ids: torch.LongTensor = None,
60
- attention_mask: Optional[torch.Tensor] = None,
61
- position_ids: Optional[torch.LongTensor] = None,
62
- past_key_values: Optional[List[torch.FloatTensor]] = None,
63
- inputs_embeds: Optional[torch.FloatTensor] = None,
64
- labels: Optional[torch.LongTensor] = None,
65
- use_cache: Optional[bool] = None,
66
- output_attentions: Optional[bool] = None,
67
- output_hidden_states: Optional[bool] = None,
68
- images: Optional[torch.FloatTensor] = None,
69
- image_sizes: Optional[List[List[int]]] = None,
70
- return_dict: Optional[bool] = None,
71
- ) -> Union[Tuple, CausalLMOutputWithPast]:
72
-
73
- if inputs_embeds is None:
74
- (
75
- input_ids,
76
- position_ids,
77
- attention_mask,
78
- past_key_values,
79
- inputs_embeds,
80
- labels
81
- ) = self.prepare_inputs_labels_for_multimodal(
82
- input_ids,
83
- position_ids,
84
- attention_mask,
85
- past_key_values,
86
- labels,
87
- images,
88
- image_sizes
89
- )
90
-
91
- return super().forward(
92
- input_ids=input_ids,
93
- attention_mask=attention_mask,
94
- position_ids=position_ids,
95
- past_key_values=past_key_values,
96
- inputs_embeds=inputs_embeds,
97
- labels=labels,
98
- use_cache=use_cache,
99
- output_attentions=output_attentions,
100
- output_hidden_states=output_hidden_states,
101
- return_dict=return_dict
102
- )
103
-
104
- @torch.no_grad()
105
- def generate(
106
- self,
107
- inputs: Optional[torch.Tensor] = None,
108
- images: Optional[torch.Tensor] = None,
109
- image_sizes: Optional[torch.Tensor] = None,
110
- **kwargs,
111
- ) -> Union[GenerateOutput, torch.LongTensor]:
112
- position_ids = kwargs.pop("position_ids", None)
113
- attention_mask = kwargs.pop("attention_mask", None)
114
- if "inputs_embeds" in kwargs:
115
- raise NotImplementedError("`inputs_embeds` is not supported")
116
-
117
- if images is not None:
118
- (
119
- inputs,
120
- position_ids,
121
- attention_mask,
122
- _,
123
- inputs_embeds,
124
- _
125
- ) = self.prepare_inputs_labels_for_multimodal(
126
- inputs,
127
- position_ids,
128
- attention_mask,
129
- None,
130
- None,
131
- images,
132
- image_sizes=image_sizes
133
- )
134
- else:
135
- inputs_embeds = self.get_model().embed_tokens(inputs)
136
-
137
- return super().generate(
138
- position_ids=position_ids,
139
- attention_mask=attention_mask,
140
- inputs_embeds=inputs_embeds,
141
- **kwargs
142
- )
143
-
144
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
- inputs_embeds=None, **kwargs):
146
- images = kwargs.pop("images", None)
147
- image_sizes = kwargs.pop("image_sizes", None)
148
- inputs = super().prepare_inputs_for_generation(
149
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
- )
151
- if images is not None:
152
- inputs['images'] = images
153
- if image_sizes is not None:
154
- inputs['image_sizes'] = image_sizes
155
- return inputs
156
-
157
- AutoConfig.register("llava_llama", LlavaConfig)
158
- AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/language_model/llava_mistral.py DELETED
@@ -1,158 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
- from torch.nn import CrossEntropyLoss
21
-
22
- from transformers import AutoConfig, AutoModelForCausalLM, \
23
- MistralConfig, MistralModel, MistralForCausalLM
24
-
25
- from transformers.modeling_outputs import CausalLMOutputWithPast
26
- from transformers.generation.utils import GenerateOutput
27
-
28
- from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
-
30
-
31
- class LlavaMistralConfig(MistralConfig):
32
- model_type = "llava_mistral"
33
-
34
-
35
- class LlavaMistralModel(LlavaMetaModel, MistralModel):
36
- config_class = LlavaMistralConfig
37
-
38
- def __init__(self, config: MistralConfig):
39
- super(LlavaMistralModel, self).__init__(config)
40
-
41
-
42
- class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43
- config_class = LlavaMistralConfig
44
-
45
- def __init__(self, config):
46
- super(MistralForCausalLM, self).__init__(config)
47
- self.model = LlavaMistralModel(config)
48
-
49
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
-
51
- # Initialize weights and apply final processing
52
- self.post_init()
53
-
54
- def get_model(self):
55
- return self.model
56
-
57
- def forward(
58
- self,
59
- input_ids: torch.LongTensor = None,
60
- attention_mask: Optional[torch.Tensor] = None,
61
- position_ids: Optional[torch.LongTensor] = None,
62
- past_key_values: Optional[List[torch.FloatTensor]] = None,
63
- inputs_embeds: Optional[torch.FloatTensor] = None,
64
- labels: Optional[torch.LongTensor] = None,
65
- use_cache: Optional[bool] = None,
66
- output_attentions: Optional[bool] = None,
67
- output_hidden_states: Optional[bool] = None,
68
- images: Optional[torch.FloatTensor] = None,
69
- image_sizes: Optional[List[List[int]]] = None,
70
- return_dict: Optional[bool] = None,
71
- ) -> Union[Tuple, CausalLMOutputWithPast]:
72
-
73
- if inputs_embeds is None:
74
- (
75
- input_ids,
76
- position_ids,
77
- attention_mask,
78
- past_key_values,
79
- inputs_embeds,
80
- labels
81
- ) = self.prepare_inputs_labels_for_multimodal(
82
- input_ids,
83
- position_ids,
84
- attention_mask,
85
- past_key_values,
86
- labels,
87
- images,
88
- image_sizes
89
- )
90
-
91
- return super().forward(
92
- input_ids=input_ids,
93
- attention_mask=attention_mask,
94
- position_ids=position_ids,
95
- past_key_values=past_key_values,
96
- inputs_embeds=inputs_embeds,
97
- labels=labels,
98
- use_cache=use_cache,
99
- output_attentions=output_attentions,
100
- output_hidden_states=output_hidden_states,
101
- return_dict=return_dict
102
- )
103
-
104
- @torch.no_grad()
105
- def generate(
106
- self,
107
- inputs: Optional[torch.Tensor] = None,
108
- images: Optional[torch.Tensor] = None,
109
- image_sizes: Optional[torch.Tensor] = None,
110
- **kwargs,
111
- ) -> Union[GenerateOutput, torch.LongTensor]:
112
- position_ids = kwargs.pop("position_ids", None)
113
- attention_mask = kwargs.pop("attention_mask", None)
114
- if "inputs_embeds" in kwargs:
115
- raise NotImplementedError("`inputs_embeds` is not supported")
116
-
117
- if images is not None:
118
- (
119
- inputs,
120
- position_ids,
121
- attention_mask,
122
- _,
123
- inputs_embeds,
124
- _
125
- ) = self.prepare_inputs_labels_for_multimodal(
126
- inputs,
127
- position_ids,
128
- attention_mask,
129
- None,
130
- None,
131
- images,
132
- image_sizes=image_sizes
133
- )
134
- else:
135
- inputs_embeds = self.get_model().embed_tokens(inputs)
136
-
137
- return super().generate(
138
- position_ids=position_ids,
139
- attention_mask=attention_mask,
140
- inputs_embeds=inputs_embeds,
141
- **kwargs
142
- )
143
-
144
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
- inputs_embeds=None, **kwargs):
146
- images = kwargs.pop("images", None)
147
- image_sizes = kwargs.pop("image_sizes", None)
148
- inputs = super().prepare_inputs_for_generation(
149
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
- )
151
- if images is not None:
152
- inputs['images'] = images
153
- if image_sizes is not None:
154
- inputs['image_sizes'] = image_sizes
155
- return inputs
156
-
157
- AutoConfig.register("llava_mistral", LlavaMistralConfig)
158
- AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/language_model/llava_mpt.py DELETED
@@ -1,97 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import Optional, Tuple
17
-
18
- import torch
19
-
20
- from transformers import AutoConfig, AutoModelForCausalLM, \
21
- MptConfig, MptForCausalLM, MptModel
22
- from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23
-
24
-
25
- class LlavaMptConfig(MptConfig):
26
- model_type = "llava_mpt"
27
-
28
-
29
- class LlavaMptModel(LlavaMetaModel, MptModel):
30
- config_class = LlavaMptConfig
31
-
32
- def __init__(self, config: MptConfig):
33
- config.hidden_size = config.d_model
34
- super(LlavaMptModel, self).__init__(config)
35
-
36
- def embed_tokens(self, x):
37
- return self.wte(x)
38
-
39
-
40
- class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41
- config_class = LlavaMptConfig
42
- supports_gradient_checkpointing = True
43
-
44
- def __init__(self, config):
45
- super(MptForCausalLM, self).__init__(config)
46
-
47
- self.transformer = LlavaMptModel(config)
48
- self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
-
50
- # Initialize weights and apply final processing
51
- self.post_init()
52
-
53
- def get_model(self):
54
- return self.transformer
55
-
56
- def _set_gradient_checkpointing(self, module, value=False):
57
- if isinstance(module, LlavaMptModel):
58
- module.gradient_checkpointing = value
59
-
60
- def forward(
61
- self,
62
- input_ids: Optional[torch.LongTensor] = None,
63
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64
- attention_mask: Optional[torch.Tensor] = None,
65
- inputs_embeds: Optional[torch.Tensor] = None,
66
- labels: Optional[torch.Tensor] = None,
67
- use_cache: Optional[bool] = None,
68
- output_attentions: Optional[bool] = None,
69
- output_hidden_states: Optional[bool] = None,
70
- return_dict: Optional[bool] = None,
71
- images=None):
72
-
73
- input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74
-
75
- return super().forward(
76
- input_ids,
77
- past_key_values=past_key_values,
78
- attention_mask=attention_mask,
79
- inputs_embeds=inputs_embeds,
80
- labels=labels,
81
- use_cache=use_cache,
82
- output_attentions=output_attentions,
83
- output_hidden_states=output_hidden_states,
84
- return_dict=return_dict,
85
- )
86
-
87
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88
- images = kwargs.pop("images", None)
89
- _inputs = super().prepare_inputs_for_generation(
90
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91
- )
92
- _inputs['images'] = images
93
- return _inputs
94
-
95
-
96
- AutoConfig.register("llava_mpt", LlavaMptConfig)
97
- AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/language_model/llava_olmo.py DELETED
@@ -1,115 +0,0 @@
1
- import json
2
- import torch
3
- import llava.model.language_model.llava_olmo1p58b as llava_olmo ##
4
- import llava.model.language_model.llava_llama as llava_llama
5
-
6
- from OLMo_Bitnet_1B.modeling_olmo import OLMoForCausalLM
7
- from transformers import AutoModelForCausalLM, LlavaForConditionalGeneration, AutoTokenizer, pipeline, TextStreamer
8
- from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast
9
- from PIL import Image
10
- import requests
11
- from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
12
- from llava.conversation import conv_templates, SeparatorStyle
13
- from transformers import AutoProcessor
14
-
15
- def count_parameters(model):
16
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
17
-
18
- device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
19
- DEFAULT_IMAGE_TOKEN = "<image>"
20
- IMAGE_TOKEN_INDEX = -200
21
-
22
- '''
23
- # TO LOAD MODEL FROM CHECKPOINT PATH
24
- with open('/home/jsundara/work_disk/Models/LLaVA/checkpoints/llava-LlavaOLMoBitnet1B-Run3-finetune/config.json') as json_file:
25
- data = json.load(json_file)
26
-
27
- config_class = llava_olmo.LlavaOLMoBitnet1BConfig(**data)
28
- lolmo = llava_olmo.LlavaOLMoBitnet1BForCausalLM(config_class).to(device)
29
- lolmo.model.vision_tower.load_model()
30
- weight_checkpoint = torch.load('/home/jsundara/work_disk/Models/LLaVA/checkpoints/llava-LlavaOLMoBitnet1B-Run3-finetune/pytorch_model.bin')
31
- lolmo.load_state_dict(weight_checkpoint)
32
- '''
33
-
34
- lolmo = AutoModelForCausalLM.from_pretrained('IntelLabs/LlavaOLMoBitnet1B').to(device)
35
-
36
- llava_processor = AutoProcessor.from_pretrained('llava-hf/llava-1.5-13b-hf')
37
-
38
-
39
- image_processor = lolmo.model.vision_tower.image_processor
40
- tokenizer = AutoTokenizer.from_pretrained(
41
- "NousResearch/OLMo-Bitnet-1B",
42
- model_max_length=2048,
43
- padding_side="right",
44
- pad_token_id=1,
45
- use_fast=True,
46
- legacy=False,
47
- unk_token='<|padding|>',
48
- )
49
-
50
-
51
- url2 = "https://farm6.staticflickr.com/5218/5397878602_ef496d0159_z.jpg"
52
- url5 = "https://farm9.staticflickr.com/8263/8846712830_ac9887294b_z.jpg"
53
- url0 = "https://farm8.staticflickr.com/7147/6583432999_3ec6f513bd_z.jpg"
54
-
55
- url = "https://farm3.staticflickr.com/2157/2439959136_d932f4e816_z.jpg"
56
-
57
- url1 = "https://farm4.staticflickr.com/3775/9166874081_978dce0d74_z.jpg"
58
- image = Image.open(requests.get(url, stream=True).raw)
59
- image_tensor = process_images([image], image_processor, lolmo.config)[0]
60
-
61
- print(lolmo)
62
- print(f'L-OLMO num of params: {count_parameters(lolmo)}')
63
- print(f'LLM num of params: {count_parameters(lolmo.model.transformer)}')
64
- print(f'Connector num of params: {count_parameters(lolmo.model.mm_projector)}')
65
-
66
- text = "What are the four major tournaments of the sport shown in the image?"
67
- text = DEFAULT_IMAGE_TOKEN + '\n' + text
68
- conv = conv_templates['llava_v1'].copy()
69
- conv.append_message(conv.roles[0], text)
70
- conv.append_message(conv.roles[1], None)
71
- prompt = conv.get_prompt()
72
-
73
- text_tokens = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
74
-
75
- response = lolmo.generate(images=image_tensor.unsqueeze(0).to(device), inputs=text_tokens, max_new_tokens=400)
76
- decoded_text = tokenizer.batch_decode(response, skip_special_tokens=True)[0]
77
- print("\n\n", "-"*100)
78
- print(decoded_text[:decoded_text.find('</s>')])
79
- print("-"*100)
80
-
81
-
82
- #
83
- ##
84
- #
85
- #
86
- #
87
- '''
88
- # ORIGINAL CODE WITH ONLY OLMO:
89
- with open('llava/config.json') as json_file:
90
- data = json.load(json_file)
91
-
92
- text = "Paris is a historic city with architectural marvels. It is also "
93
- # text = ["Language modeling is "]
94
-
95
- config_class = llava_olmo.LlavaOLMoBitnet1BConfig(**data)
96
- lolmo = llava_olmo.LlavaOLMoBitnet1BForCausalLM(config_class).to(device)
97
- lolmo.load_state_dict(torch.load('OLMo_Bitnet_1B/pytorch_model.bin'), strict=False)
98
-
99
- olmo = OLMoForCausalLM(config_class).to(device)
100
- olmo.load_state_dict(torch.load('OLMo_Bitnet_1B/pytorch_model.bin'))
101
- actual_olmo = OLMoForCausalLM.from_pretrained("allenai/OLMo-1B").to(device)
102
-
103
- actual_olmo_tokenizer = OLMoTokenizerFast.from_pretrained("allenai/OLMo-1B")
104
- olmo_tokenizer = AutoTokenizer.from_pretrained("NousResearch/OLMo-Bitnet-1B")
105
-
106
- olmo_tokens = olmo_tokenizer(text, return_tensors='pt', return_token_type_ids=False).to(device)
107
- # olmo_tokens = actual_olmo_tokenizer(text, return_tensors='pt', return_token_type_ids=False).to(device)
108
-
109
-
110
- response = lolmo.generate(inputs=olmo_tokens['input_ids'], attention_mask=olmo_tokens['attention_mask'], max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
111
- # response = olmo.generate(inputs=olmo_tokens['input_ids'], attention_mask=olmo_tokens['attention_mask'], max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
112
-
113
-
114
- print(olmo_tokenizer.batch_decode(response, skip_special_tokens=True)[0])
115
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/language_model/llava_olmo1p58b.py DELETED
@@ -1,164 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from transformers import AutoConfig, AutoModelForCausalLM
22
-
23
- from OLMo_Bitnet_1B.model import OLMo
24
- from OLMo_Bitnet_1B.configuration_olmo import OLMoConfig
25
- from OLMo_Bitnet_1B.modeling_olmo import OLMoForCausalLM
26
-
27
-
28
- from transformers.modeling_outputs import CausalLMOutputWithPast
29
- from transformers.generation.utils import GenerateOutput
30
-
31
- from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
32
-
33
-
34
- class LlavaOLMoBitnet1BConfig(OLMoConfig):
35
- model_type = "IntelLabs/LlavaOLMoBitnet1B"
36
-
37
-
38
- class LlavaOLMoBitnet1BModel(LlavaMetaModel, OLMo):
39
- config_class = LlavaOLMoBitnet1BConfig
40
-
41
- def __init__(self, config: OLMoConfig):
42
- super(LlavaOLMoBitnet1BModel, self).__init__(config)
43
-
44
-
45
- class LlavaOLMoBitnet1BForCausalLM(OLMoForCausalLM, LlavaMetaForCausalLM):
46
- config_class = LlavaOLMoBitnet1BConfig
47
-
48
- def __init__(self, config):
49
- super(OLMoForCausalLM, self).__init__(config)
50
- self.model = LlavaOLMoBitnet1BModel(config)
51
- self.vocab_size = config.vocab_size
52
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53
- self.embed_layer = self.get_input_embeddings()
54
- self.model.vision_tower.load_model()
55
-
56
- # Initialize weights and apply final processing
57
- self.post_init()
58
-
59
- def get_model(self):
60
- return self.model
61
-
62
- def forward(
63
- self,
64
- input_ids: torch.LongTensor = None,
65
- attention_mask: Optional[torch.Tensor] = None,
66
- position_ids: Optional[torch.LongTensor] = None,
67
- past_key_values: Optional[List[torch.FloatTensor]] = None,
68
- inputs_embeds: Optional[torch.FloatTensor] = None,
69
- labels: Optional[torch.LongTensor] = None,
70
- use_cache: Optional[bool] = None,
71
- output_attentions: Optional[bool] = None,
72
- output_hidden_states: Optional[bool] = None,
73
- images: Optional[torch.FloatTensor] = None,
74
- image_sizes: Optional[List[List[int]]] = None,
75
- return_dict: Optional[bool] = None,
76
- ) -> Union[Tuple, CausalLMOutputWithPast]:
77
-
78
- if inputs_embeds is None:
79
- (
80
- input_ids,
81
- position_ids,
82
- attention_mask,
83
- past_key_values,
84
- inputs_embeds,
85
- labels
86
- ) = self.prepare_inputs_labels_for_multimodal(
87
- input_ids,
88
- position_ids,
89
- attention_mask,
90
- past_key_values,
91
- labels,
92
- images,
93
- image_sizes
94
- )
95
-
96
- return super().forward(
97
- input_ids=input_ids,
98
- attention_mask=attention_mask,
99
- # position_ids=position_ids,
100
- past_key_values=past_key_values,
101
- inputs_embeds=inputs_embeds,
102
- labels=labels,
103
- use_cache=use_cache,
104
- output_attentions=output_attentions,
105
- output_hidden_states=output_hidden_states,
106
- return_dict=return_dict
107
- )
108
-
109
- @torch.no_grad()
110
- def generate(
111
- self,
112
- inputs: Optional[torch.Tensor] = None,
113
- images: Optional[torch.Tensor] = None,
114
- image_sizes: Optional[torch.Tensor] = None,
115
- **kwargs,
116
- ) -> Union[GenerateOutput, torch.LongTensor]:
117
- position_ids = kwargs.pop("position_ids", None)
118
- attention_mask = kwargs.pop("attention_mask", None)
119
- if "inputs_embeds" in kwargs:
120
- raise NotImplementedError("`inputs_embeds` is not supported")
121
-
122
- if images is not None:
123
- (
124
- inputs,
125
- position_ids,
126
- attention_mask,
127
- _,
128
- inputs_embeds,
129
- _
130
- ) = self.prepare_inputs_labels_for_multimodal(
131
- inputs,
132
- position_ids,
133
- attention_mask,
134
- None,
135
- None,
136
- images,
137
- image_sizes=image_sizes
138
- )
139
- else:
140
- # inputs_embeds = self.get_model().embed_tokens(inputs)
141
- inputs_embeds = self.embed_layer(inputs)
142
-
143
- return super().generate(
144
- position_ids=position_ids,
145
- attention_mask=attention_mask,
146
- inputs_embeds=inputs_embeds,
147
- **kwargs
148
- )
149
-
150
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
151
- inputs_embeds=None, **kwargs):
152
- images = kwargs.pop("images", None)
153
- image_sizes = kwargs.pop("image_sizes", None)
154
- inputs = super().prepare_inputs_for_generation(
155
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
156
- )
157
- if images is not None:
158
- inputs['images'] = images
159
- if image_sizes is not None:
160
- inputs['image_sizes'] = image_sizes
161
- return inputs
162
-
163
- AutoConfig.register("IntelLabs/LlavaOLMoBitnet1B", LlavaOLMoBitnet1BConfig)
164
- AutoModelForCausalLM.register(LlavaOLMoBitnet1BConfig, LlavaOLMoBitnet1BForCausalLM)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/llava_arch.py DELETED
@@ -1,369 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from abc import ABC, abstractmethod
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from .multimodal_encoder.builder import build_vision_tower
22
- from .multimodal_projector.builder import build_vision_projector
23
-
24
- from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
-
26
- from llava.mm_utils import get_anyres_image_grid_shape
27
-
28
-
29
- class LlavaMetaModel:
30
-
31
- def __init__(self, config):
32
- super(LlavaMetaModel, self).__init__(config)
33
-
34
- if hasattr(config, "mm_vision_tower"):
35
- self.vision_tower = build_vision_tower(config, delay_load=True)
36
- self.mm_projector = build_vision_projector(config)
37
-
38
- if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
39
- self.image_newline = nn.Parameter(
40
- torch.empty(config.hidden_size, dtype=self.dtype)
41
- )
42
-
43
- def get_vision_tower(self):
44
- vision_tower = getattr(self, 'vision_tower', None)
45
- if type(vision_tower) is list:
46
- vision_tower = vision_tower[0]
47
- return vision_tower
48
-
49
- def initialize_vision_modules(self, model_args, fsdp=None):
50
- vision_tower = model_args.vision_tower
51
- mm_vision_select_layer = model_args.mm_vision_select_layer
52
- mm_vision_select_feature = model_args.mm_vision_select_feature
53
- pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
54
- mm_patch_merge_type = model_args.mm_patch_merge_type
55
-
56
- self.config.mm_vision_tower = vision_tower
57
-
58
- if self.get_vision_tower() is None:
59
- vision_tower = build_vision_tower(model_args)
60
-
61
- if fsdp is not None and len(fsdp) > 0:
62
- self.vision_tower = [vision_tower]
63
- else:
64
- self.vision_tower = vision_tower
65
- else:
66
- if fsdp is not None and len(fsdp) > 0:
67
- vision_tower = self.vision_tower[0]
68
- else:
69
- vision_tower = self.vision_tower
70
- vision_tower.load_model()
71
-
72
- self.config.use_mm_proj = True
73
- self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
74
- self.config.mm_hidden_size = vision_tower.hidden_size
75
- self.config.mm_vision_select_layer = mm_vision_select_layer
76
- self.config.mm_vision_select_feature = mm_vision_select_feature
77
- self.config.mm_patch_merge_type = mm_patch_merge_type
78
-
79
- if getattr(self, 'mm_projector', None) is None:
80
- self.mm_projector = build_vision_projector(self.config)
81
-
82
- if 'unpad' in mm_patch_merge_type:
83
- embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
84
- self.image_newline = nn.Parameter(
85
- torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
86
- )
87
- else:
88
- # In case it is frozen by LoRA
89
- for p in self.mm_projector.parameters():
90
- p.requires_grad = True
91
-
92
- if pretrain_mm_mlp_adapter is not None:
93
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
94
- def get_w(weights, keyword):
95
- return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
96
-
97
- self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
98
-
99
-
100
- def unpad_image(tensor, original_size):
101
- """
102
- Unpads a PyTorch tensor of a padded and resized image.
103
-
104
- Args:
105
- tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
106
- original_size (tuple): The original size of PIL image (width, height).
107
-
108
- Returns:
109
- torch.Tensor: The unpadded image tensor.
110
- """
111
- original_width, original_height = original_size
112
- current_height, current_width = tensor.shape[1:]
113
-
114
- original_aspect_ratio = original_width / original_height
115
- current_aspect_ratio = current_width / current_height
116
-
117
- if original_aspect_ratio > current_aspect_ratio:
118
- scale_factor = current_width / original_width
119
- new_height = int(original_height * scale_factor)
120
- padding = (current_height - new_height) // 2
121
- unpadded_tensor = tensor[:, padding:current_height - padding, :]
122
- else:
123
- scale_factor = current_height / original_height
124
- new_width = int(original_width * scale_factor)
125
- padding = (current_width - new_width) // 2
126
- unpadded_tensor = tensor[:, :, padding:current_width - padding]
127
-
128
- return unpadded_tensor
129
-
130
-
131
- class LlavaMetaForCausalLM(ABC):
132
-
133
- @abstractmethod
134
- def get_model(self):
135
- pass
136
-
137
- def get_vision_tower(self):
138
- return self.get_model().get_vision_tower()
139
-
140
- def encode_images(self, images):
141
- image_features = self.get_model().get_vision_tower()(images)
142
- image_features = image_features.to(self.get_model().device) # add for local inference runs
143
- image_features = self.get_model().mm_projector(image_features)
144
- return image_features
145
-
146
- def prepare_inputs_labels_for_multimodal(
147
- self, input_ids, position_ids, attention_mask, past_key_values, labels,
148
- images, image_sizes=None
149
- ):
150
- vision_tower = self.get_vision_tower()
151
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
152
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
153
-
154
- if type(images) is list or images.ndim == 5:
155
- if type(images) is list:
156
- images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
157
- concat_images = torch.cat([image for image in images], dim=0)
158
- image_features = self.encode_images(concat_images)
159
- split_sizes = [image.shape[0] for image in images]
160
- image_features = torch.split(image_features, split_sizes, dim=0)
161
- mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
162
- image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
163
- if mm_patch_merge_type == 'flat':
164
- image_features = [x.flatten(0, 1) for x in image_features]
165
- elif mm_patch_merge_type.startswith('spatial'):
166
- new_image_features = []
167
- for image_idx, image_feature in enumerate(image_features):
168
- if image_feature.shape[0] > 1:
169
- base_image_feature = image_feature[0]
170
- image_feature = image_feature[1:]
171
- height = width = self.get_vision_tower().num_patches_per_side
172
- assert height * width == base_image_feature.shape[0]
173
- if image_aspect_ratio == 'anyres':
174
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
175
- image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
176
- else:
177
- raise NotImplementedError
178
- if 'unpad' in mm_patch_merge_type:
179
- image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
180
- image_feature = image_feature.flatten(1, 2).flatten(2, 3)
181
- image_feature = unpad_image(image_feature, image_sizes[image_idx])
182
- image_feature = torch.cat((
183
- image_feature,
184
- self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
185
- ), dim=-1)
186
- image_feature = image_feature.flatten(1, 2).transpose(0, 1)
187
- else:
188
- image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
189
- image_feature = image_feature.flatten(0, 3)
190
- image_feature = torch.cat((base_image_feature, image_feature), dim=0)
191
- else:
192
- image_feature = image_feature[0]
193
- if 'unpad' in mm_patch_merge_type:
194
- image_feature = torch.cat((
195
- image_feature,
196
- self.model.image_newline[None].to(image_feature.device)
197
- ), dim=0)
198
- new_image_features.append(image_feature)
199
- image_features = new_image_features
200
- else:
201
- raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
202
- else:
203
- image_features = self.encode_images(images)
204
-
205
- # TODO: image start / end is not implemented here to support pretraining.
206
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
207
- raise NotImplementedError
208
-
209
- # Let's just add dummy tensors if they do not exist,
210
- # it is a headache to deal with None all the time.
211
- # But it is not ideal, and if you have a better idea,
212
- # please open an issue / submit a PR, thanks.
213
- _labels = labels
214
- _position_ids = position_ids
215
- _attention_mask = attention_mask
216
- if attention_mask is None:
217
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
218
- else:
219
- attention_mask = attention_mask.bool()
220
- if position_ids is None:
221
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
222
- if labels is None:
223
- labels = torch.full_like(input_ids, IGNORE_INDEX)
224
-
225
- # remove the padding using attention_mask -- FIXME
226
- _input_ids = input_ids
227
- input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
228
- labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
229
-
230
- new_input_embeds = []
231
- new_labels = []
232
- cur_image_idx = 0
233
- for batch_idx, cur_input_ids in enumerate(input_ids):
234
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
235
- if num_images == 0:
236
- cur_image_features = image_features[cur_image_idx]
237
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
238
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
239
- new_input_embeds.append(cur_input_embeds)
240
- new_labels.append(labels[batch_idx])
241
- cur_image_idx += 1
242
- continue
243
-
244
- image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
245
- cur_input_ids_noim = []
246
- cur_labels = labels[batch_idx]
247
- cur_labels_noim = []
248
- for i in range(len(image_token_indices) - 1):
249
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
250
- cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
251
- split_sizes = [x.shape[0] for x in cur_labels_noim]
252
- cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
253
- cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
254
- cur_new_input_embeds = []
255
- cur_new_labels = []
256
-
257
- for i in range(num_images + 1):
258
- cur_new_input_embeds.append(cur_input_embeds_no_im[i])
259
- cur_new_labels.append(cur_labels_noim[i])
260
- if i < num_images:
261
- cur_image_features = image_features[cur_image_idx]
262
- cur_image_idx += 1
263
- cur_new_input_embeds.append(cur_image_features)
264
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
265
-
266
- cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
267
-
268
- cur_new_input_embeds = torch.cat(cur_new_input_embeds)
269
- cur_new_labels = torch.cat(cur_new_labels)
270
-
271
- new_input_embeds.append(cur_new_input_embeds)
272
- new_labels.append(cur_new_labels)
273
-
274
- # Truncate sequences to max length as image embeddings can make the sequence longer
275
- tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
276
- if tokenizer_model_max_length is not None:
277
- new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
278
- new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
279
-
280
- # Combine them
281
- max_len = max(x.shape[0] for x in new_input_embeds)
282
- batch_size = len(new_input_embeds)
283
-
284
- new_input_embeds_padded = []
285
- new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
286
- attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
287
- position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
288
-
289
- for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
290
- cur_len = cur_new_embed.shape[0]
291
- if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
292
- new_input_embeds_padded.append(torch.cat((
293
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
294
- cur_new_embed
295
- ), dim=0))
296
- if cur_len > 0:
297
- new_labels_padded[i, -cur_len:] = cur_new_labels
298
- attention_mask[i, -cur_len:] = True
299
- position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
300
- else:
301
- new_input_embeds_padded.append(torch.cat((
302
- cur_new_embed,
303
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
304
- ), dim=0))
305
- if cur_len > 0:
306
- new_labels_padded[i, :cur_len] = cur_new_labels
307
- attention_mask[i, :cur_len] = True
308
- position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
309
-
310
- new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
311
-
312
- if _labels is None:
313
- new_labels = None
314
- else:
315
- new_labels = new_labels_padded
316
-
317
- if _attention_mask is None:
318
- attention_mask = None
319
- else:
320
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
321
-
322
- if _position_ids is None:
323
- position_ids = None
324
-
325
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
326
-
327
- def initialize_vision_tokenizer(self, model_args, tokenizer):
328
- if model_args.mm_use_im_patch_token:
329
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
330
- self.resize_token_embeddings(len(tokenizer))
331
-
332
- if model_args.mm_use_im_start_end:
333
- num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
334
- self.resize_token_embeddings(len(tokenizer))
335
-
336
- if num_new_tokens > 0:
337
- input_embeddings = self.get_input_embeddings().weight.data
338
- output_embeddings = self.get_output_embeddings().weight.data
339
-
340
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
341
- dim=0, keepdim=True)
342
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
343
- dim=0, keepdim=True)
344
-
345
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
346
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
347
-
348
- if model_args.tune_mm_mlp_adapter:
349
- for p in self.get_input_embeddings().parameters():
350
- p.requires_grad = True
351
- for p in self.get_output_embeddings().parameters():
352
- p.requires_grad = False
353
-
354
- if model_args.pretrain_mm_mlp_adapter:
355
- mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
356
- embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
357
- assert num_new_tokens == 2
358
- if input_embeddings.shape == embed_tokens_weight.shape:
359
- input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
360
- elif embed_tokens_weight.shape[0] == num_new_tokens:
361
- input_embeddings[-num_new_tokens:] = embed_tokens_weight
362
- else:
363
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
364
- elif model_args.mm_use_im_patch_token:
365
- if model_args.tune_mm_mlp_adapter:
366
- for p in self.get_input_embeddings().parameters():
367
- p.requires_grad = False
368
- for p in self.get_output_embeddings().parameters():
369
- p.requires_grad = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/make_delta.py DELETED
@@ -1,52 +0,0 @@
1
- """
2
- Usage:
3
- python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
- """
5
- import argparse
6
-
7
- import torch
8
- from tqdm import tqdm
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
- from llava.model.utils import auto_upgrade
11
-
12
-
13
- def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
- print("Loading base model")
15
- base = AutoModelForCausalLM.from_pretrained(
16
- base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
-
18
- print("Loading target model")
19
- auto_upgrade(target_model_path)
20
- target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
-
22
- print("Calculating delta")
23
- for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
- if name not in base.state_dict():
25
- assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
- continue
27
- if param.data.shape == base.state_dict()[name].shape:
28
- param.data -= base.state_dict()[name]
29
- else:
30
- assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
- bparam = base.state_dict()[name]
32
- param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
-
34
- print("Saving delta")
35
- if hub_repo_id:
36
- kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
- else:
38
- kwargs = {}
39
- target.save_pretrained(delta_path, **kwargs)
40
- target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
- target_tokenizer.save_pretrained(delta_path, **kwargs)
42
-
43
-
44
- if __name__ == "__main__":
45
- parser = argparse.ArgumentParser()
46
- parser.add_argument("--base-model-path", type=str, required=True)
47
- parser.add_argument("--target-model-path", type=str, required=True)
48
- parser.add_argument("--delta-path", type=str, required=True)
49
- parser.add_argument("--hub-repo-id", type=str, default=None)
50
- args = parser.parse_args()
51
-
52
- make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/multimodal_encoder/__pycache__/builder.cpython-310.pyc DELETED
Binary file (730 Bytes)
 
model/multimodal_encoder/__pycache__/builder.cpython-311.pyc DELETED
Binary file (1.24 kB)
 
model/multimodal_encoder/__pycache__/builder.cpython-312.pyc DELETED
Binary file (1.07 kB)
 
model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc DELETED
Binary file (5.31 kB)
 
model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc DELETED
Binary file (9.97 kB)
 
model/multimodal_encoder/__pycache__/clip_encoder.cpython-312.pyc DELETED
Binary file (9.8 kB)
 
model/multimodal_encoder/builder.py DELETED
@@ -1,15 +0,0 @@
1
- import os
2
- from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3
-
4
-
5
- def build_vision_tower(vision_tower_cfg, **kwargs):
6
- vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7
- is_absolute_path_exists = os.path.exists(vision_tower)
8
- use_s2 = getattr(vision_tower_cfg, 's2', False)
9
- if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
10
- if use_s2:
11
- return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
12
- else:
13
- return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
14
-
15
- raise ValueError(f'Unknown vision tower: {vision_tower}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/multimodal_encoder/clip_encoder.py DELETED
@@ -1,147 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
-
6
-
7
- class CLIPVisionTower(nn.Module):
8
- def __init__(self, vision_tower, args, delay_load=False):
9
- super().__init__()
10
-
11
- self.is_loaded = False
12
-
13
- self.vision_tower_name = vision_tower
14
- self.select_layer = args.mm_vision_select_layer
15
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
-
17
- if not delay_load:
18
- self.load_model()
19
- elif getattr(args, 'unfreeze_mm_vision_tower', False):
20
- self.load_model()
21
- else:
22
- self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
-
24
- def load_model(self, device_map=None):
25
- if self.is_loaded:
26
- print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27
- return
28
-
29
- self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31
- self.vision_tower.requires_grad_(False)
32
-
33
- self.is_loaded = True
34
-
35
- def feature_select(self, image_forward_outs):
36
- image_features = image_forward_outs.hidden_states[self.select_layer]
37
- if self.select_feature == 'patch':
38
- image_features = image_features[:, 1:]
39
- elif self.select_feature == 'cls_patch':
40
- image_features = image_features
41
- else:
42
- raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
- return image_features
44
-
45
- @torch.no_grad()
46
- def forward(self, images):
47
- if type(images) is list:
48
- image_features = []
49
- for image in images:
50
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
- image_features.append(image_feature)
53
- else:
54
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
-
57
- return image_features
58
-
59
- @property
60
- def dummy_feature(self):
61
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
-
63
- @property
64
- def dtype(self):
65
- return self.vision_tower.dtype
66
-
67
- @property
68
- def device(self):
69
- return self.vision_tower.device
70
-
71
- @property
72
- def config(self):
73
- if self.is_loaded:
74
- return self.vision_tower.config
75
- else:
76
- return self.cfg_only
77
-
78
- @property
79
- def hidden_size(self):
80
- return self.config.hidden_size
81
-
82
- @property
83
- def num_patches_per_side(self):
84
- return self.config.image_size // self.config.patch_size
85
-
86
- @property
87
- def num_patches(self):
88
- return (self.config.image_size // self.config.patch_size) ** 2
89
-
90
-
91
-
92
- class CLIPVisionTowerS2(CLIPVisionTower):
93
- def __init__(self, vision_tower, args, delay_load=False):
94
- super().__init__(vision_tower, args, delay_load)
95
-
96
- self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
97
- self.s2_scales = list(map(int, self.s2_scales.split(',')))
98
- self.s2_scales.sort()
99
- self.s2_split_size = self.s2_scales[0]
100
- self.s2_image_size = self.s2_scales[-1]
101
-
102
- try:
103
- from s2wrapper import forward as multiscale_forward
104
- except ImportError:
105
- raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
106
- self.multiscale_forward = multiscale_forward
107
-
108
- # change resize/crop size in preprocessing to the largest image size in s2_scale
109
- if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
110
- self.image_processor.size['shortest_edge'] = self.s2_image_size
111
- self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
112
-
113
- def load_model(self, device_map=None):
114
- if self.is_loaded:
115
- print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
116
- return
117
-
118
- self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
119
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
120
- self.vision_tower.requires_grad_(False)
121
-
122
- self.image_processor.size['shortest_edge'] = self.s2_image_size
123
- self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
124
-
125
- self.is_loaded = True
126
-
127
- @torch.no_grad()
128
- def forward_feature(self, images):
129
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
130
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
131
- return image_features
132
-
133
- @torch.no_grad()
134
- def forward(self, images):
135
- if type(images) is list:
136
- image_features = []
137
- for image in images:
138
- image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
139
- image_features.append(image_feature)
140
- else:
141
- image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
142
-
143
- return image_features
144
-
145
- @property
146
- def hidden_size(self):
147
- return self.config.hidden_size * len(self.s2_scales)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/multimodal_projector/__pycache__/builder.cpython-310.pyc DELETED
Binary file (2.02 kB)
 
model/multimodal_projector/__pycache__/builder.cpython-311.pyc DELETED
Binary file (3.6 kB)
 
model/multimodal_projector/__pycache__/builder.cpython-312.pyc DELETED
Binary file (3.22 kB)
 
model/multimodal_projector/builder.py DELETED
@@ -1,51 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import re
4
-
5
-
6
- class IdentityMap(nn.Module):
7
- def __init__(self):
8
- super().__init__()
9
-
10
- def forward(self, x, *args, **kwargs):
11
- return x
12
-
13
- @property
14
- def config(self):
15
- return {"mm_projector_type": 'identity'}
16
-
17
-
18
- class SimpleResBlock(nn.Module):
19
- def __init__(self, channels):
20
- super().__init__()
21
- self.pre_norm = nn.LayerNorm(channels)
22
-
23
- self.proj = nn.Sequential(
24
- nn.Linear(channels, channels),
25
- nn.GELU(),
26
- nn.Linear(channels, channels)
27
- )
28
- def forward(self, x):
29
- x = self.pre_norm(x)
30
- return x + self.proj(x)
31
-
32
-
33
- def build_vision_projector(config, delay_load=False, **kwargs):
34
- projector_type = getattr(config, 'mm_projector_type', 'linear')
35
-
36
- if projector_type == 'linear':
37
- return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
-
39
- mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40
- if mlp_gelu_match:
41
- mlp_depth = int(mlp_gelu_match.group(1))
42
- modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
- for _ in range(1, mlp_depth):
44
- modules.append(nn.GELU())
45
- modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
- return nn.Sequential(*modules)
47
-
48
- if projector_type == 'identity':
49
- return IdentityMap()
50
-
51
- raise ValueError(f'Unknown projector type: {projector_type}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/utils.py DELETED
@@ -1,20 +0,0 @@
1
- from transformers import AutoConfig
2
-
3
-
4
- def auto_upgrade(config):
5
- cfg = AutoConfig.from_pretrained(config)
6
- if 'llava' in config and 'llava' not in cfg.model_type:
7
- assert cfg.model_type == 'llama'
8
- print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
- print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
- confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
- if confirm.lower() in ["y", "yes"]:
12
- print("Upgrading checkpoint...")
13
- assert len(cfg.architectures) == 1
14
- setattr(cfg.__class__, "model_type", "llava")
15
- cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
- cfg.save_pretrained(config)
17
- print("Checkpoint upgraded.")
18
- else:
19
- print("Checkpoint upgrade aborted.")
20
- exit(1)