Upload files with `vila-upload`.
Upload media_encoder.py
Upload media.py
Upload modeling_vila.py
Upload configuration_vila.py
Upload builder.py
Upload mm_utils.py
Upload tokenizer_utils.py
Upload siglip_encoder.py
Files changed:
- builder.py +14 -4
- configuration_vila.py +16 -8
- media.py +4 -0
- media_encoder.py +3 -2
- mm_utils.py +1 -1
- modeling_vila.py +131 -35
- siglip_encoder.py +2 -3
- tokenizer_utils.py +2 -2
builder.py (CHANGED)

```diff
@@ -22,9 +22,9 @@ from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
+import transformers
 from huggingface_hub import file_exists, repo_exists
 from huggingface_hub.utils import HFValidationError
-import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -33,8 +33,9 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizer,
 )
+
 # from .conversation import *
-from .conversation import
+from .conversation import SeparatorStyle, default_conversation
 
 SENTINEL_TOKEN = "<vila/sentinel>"
 MEDIA_TOKENS = {
@@ -51,9 +52,11 @@ DUMMY_CONVERSATION = [
     {"from": "gpt", "value": "answer"},
 ] * 10
 
+
 def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
     return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
-
+
+
 def has_tokenizer(repo_id_or_path: str) -> bool:
     # Check if the tokenizer is in a local directory
     if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
@@ -65,12 +68,14 @@ def has_tokenizer(repo_id_or_path: str) -> bool:
     except HFValidationError:
         return False
 
+
 def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
     if not hasattr(tokenizer, "sentinel_token"):
         tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
         tokenizer.sentinel_token = SENTINEL_TOKEN
         tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
 
+
 def tokenize_conversation_legacy(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -103,6 +108,7 @@ def tokenize_conversation_legacy(
 
     return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
 
+
 def tokenize_conversation(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -148,6 +154,7 @@ def tokenize_conversation(
     )
     return tokenizer_image_token(text, tokenizer, return_tensors="pt")
 
+
 def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
     _maybe_add_sentinel_token(tokenizer)
     template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
@@ -159,6 +166,7 @@ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
         stop_tokens.add(stop_token)
     return list(stop_tokens)
 
+
 def context_length_extension(config):
     orig_ctx_len = getattr(config, "max_position_embeddings", None)
     model_max_length = getattr(config, "model_max_length", None)
@@ -186,7 +194,7 @@ def build_llm_and_tokenizer(
 
     # Quantization related
     quantization_restore_from_checkpoint = False
-
+
     if quantization_restore_from_checkpoint:
         fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
 
@@ -215,6 +223,8 @@
     if getattr(config, "chat_template", None) is not None:
         print(f"Using chat template: {config.chat_template}")
         fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
+        if not os.path.exists(fpath):
+            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
         with open(fpath) as fd:
             chat_template = fd.read()
         tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
```
configuration_vila.py (CHANGED)

```diff
@@ -1,15 +1,24 @@
+import json
 import math
+import os
+import os.path as osp
+from copy import deepcopy
+from threading import Thread
 from typing import List, Optional
-
+
 import torch
 import torchvision
-import os, os.path as osp
-
-from threading import Thread
-from copy import deepcopy
 from PIL import Image
-from transformers import
-
+from transformers import (
+    AutoProcessor,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2PreTrainedModel,
+    TextIteratorStreamer,
+)
+
 
 class VILAConfig(PretrainedConfig):
     model_type = "vila"
@@ -82,4 +91,3 @@ class VILAConfig(PretrainedConfig):
         self.video_encoder = video_encoder
 
         super().__init__(**kwargs)
-
```
media.py (CHANGED)

```diff
@@ -20,13 +20,16 @@ MEDIA_TOKENS = {
     "video": "<vila/video>",
 }
 
+
 class Media:
     pass
 
+
 class File(Media):
     def __init__(self, path: str) -> None:
         self.path = path
 
+
 class Image(File):
     pass
 
@@ -34,6 +37,7 @@ class Image(File):
 class Video(File):
     pass
 
+
 def make_list(obj: Any) -> List:
     return obj if isinstance(obj, list) else [obj]
 
```
media_encoder.py (CHANGED)

```diff
@@ -1,8 +1,9 @@
-import torch
-from torch import nn
 from functools import partial
 from typing import Any, Dict, List, Optional
 
+import torch
+from torch import nn
+
 
 class BaseEncoder(nn.Module):
     def __init__(self, parent: nn.Module) -> None:
```
mm_utils.py (CHANGED)

```diff
@@ -26,7 +26,7 @@ import torch
 from PIL import Image
 from transformers import StoppingCriteria
 
-from
+from .constants import DEFAULT_IMAGE_TOKEN
 
 
 def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
```
modeling_vila.py (CHANGED)

```diff
@@ -1,4 +1,3 @@
-import shutil
 import copy
 import json
 import logging
@@ -6,6 +5,7 @@ import math
 import os
 import os.path
 import os.path as osp
+import shutil
 import warnings
 from abc import ABC
 from collections import OrderedDict, defaultdict, deque
@@ -15,13 +15,12 @@ from threading import Thread
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 from einops import rearrange
 from PIL import Image
-
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -34,28 +33,30 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
     Qwen2PreTrainedModel,
-    TextIteratorStreamer
+    TextIteratorStreamer,
 )
-from transformers.modeling_utils import ContextManagers, no_init_weights
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import ContextManagers, no_init_weights
 
 from .base_projector import MultimodalProjector, MultimodalProjectorConfig
 from .builder import build_llm_and_tokenizer
 from .configuration_vila import VILAConfig
-from .
-from .
-from .utils import get_model_config
+from .constants import *
+from .conversation import SeparatorStyle, default_conversation
 from .media import extract_media
+from .media_encoder import BasicImageEncoder, BasicVideoEncoder
 from .mm_utils import process_image, process_images
+from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
 from .tokenizer_utils import tokenize_conversation
-from .
-
+from .utils import get_model_config
+
 
 # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
 # quick hack for remote code
 def get_pg_manager():
     return None
 
+
 def get_model_weights_dtype(model: nn.Module):
     pass
 
@@ -72,7 +73,77 @@ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
     mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
     mm_projector = MultimodalProjector(mm_projector_cfg, config)
     return mm_projector
-
+
+
+def check_dot_in_model_path(model_path: str):
+    """Check if the model path contains dot, which will affect the remote code loading."""
+    if osp.isdir(model_path):  # local model
+        if "." in osp.abspath(model_path):
+            return True
+    else:  # remote model
+        if "." in model_path:
+            return True
+    return False
+
+
+def get_vila_version(model_path: str) -> str:
+    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
+    for version in VERSIONS:
+        if version in model_path.lower():
+            return version
+    return None
+
+
+def generate_jinja_template(conv_mode: str) -> str:
+    if conv_mode == "vicuna_v1":
+        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." %}
+{% set roles = ["USER", "ASSISTANT"] %}
+{% set sep = " " %}
+{% set sep2 = "</s>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+{% if message['role'] == roles[0] %}
+{{ roles[0] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+{% else %}
+{{ roles[1] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+{% endif %}
+{% endfor %}"""
+    elif conv_mode == "llama_3":
+        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." %}
+{% set roles = ["<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"] %}
+{% set sep = "<|eot_id|>" %}
+{% set sep2 = "<|end_of_text|>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{ roles[0] }}{{ message['content'] }}{{ sep }}
+{% else %}
+{{ roles[1] }}{{ message['content'] }}{{ sep }}
+{% endif %}
+{% endfor %}
+
+{{ sep2 }}"""
+    elif conv_mode == "hermes_2":
+        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
+{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
+{% set sep = "<|im_end|>" %}
+
+{{ system_prompt }}{{ sep }}
+
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{ roles[0] }}{{ message['content'] }}{{ sep }}
+{% else %}
+{{ roles[1] }}{{ message['content'] }}{{ sep }}
+{% endif %}
+{% endfor %}"""
+    else:
+        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
+
 
 def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
     ## skip vision tower instantiation
@@ -110,7 +181,7 @@ class VILAPretrainedModel(PreTrainedModel):
     main_input_name = "input_embeds"
     supports_gradient_checkpointing = True
     _supports_flash_attn_2 = True
-
+
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config)
         self.config = config
@@ -119,22 +190,19 @@
             llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
         else:
             raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
-
+
         # loading on cpu by default
         device_map = kwargs.get("device_map", "cpu")
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        if "auto" in device_map or "cuda" in device_map:
            self.mm_projector = self.mm_projector.cuda()
-           self.vision_tower = self.vision_tower.cuda()
+           self.vision_tower = self.vision_tower.cuda()
        # set device_map auto can autoamtically shard llm to different devices
        self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
-
-       self.encoders = {
-
-           "video": BasicVideoEncoder(self)
-       }
-
+
+       self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}
+
        self.post_config()
        self.is_loaded = True
 
@@ -143,37 +211,65 @@
         ), "At least one of the components must be instantiated."
 
     @classmethod
-    def convert_vila_dev_ckpt_to_remote(
+    def convert_vila_dev_ckpt_to_remote(
+        self,
+        model_path: str,
+        output_dir: str = None,
+        vila_version: str | None = None,
+        conv_mode: str | None = None,
+        *model_args,
+        **kwargs,
+    ):
         # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
         from huggingface_hub import HfApi, snapshot_download
 
         if os.path.isdir(model_path):
             model_path = model_path
         api = HfApi()
+
+        if check_dot_in_model_path(model_path) and output_dir is None:
+            raise ValueError(
+                f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
+            )
+        if output_dir is not None and "." in output_dir:
+            raise ValueError(
+                f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
+            )
+        if vila_version is None:
+            vila_version = get_vila_version(model_path)
+
         if api.repo_exists(model_path):
             model_path = snapshot_download(model_path, local_dir=output_dir)
             print("downloading HF model to", model_path)
-
+
         cfg_path = os.path.join(model_path, "config.json")
         config = json.load(open(cfg_path))
-        config["version"] = "2.0"
+        config["version"] = "2.0"  # nvila tag
         config["architectures"] = ["VILAForCasualLM"]
         config["auto_map"] = {
             "AutoConfig": "modeling_vila.VILAConfig",
             "AutoModel": "modeling_vila.VILAForCasualLM",
-            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
+            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
         }
         config["model_type"] = "vila"
+        if vila_version in ["vila1.5", "vila-m3"]:
+            if conv_mode is None:
+                raise ValueError(f"Please specify the conversation mode for {model_path}.")
+            config["chat_template"] = conv_mode
+            jinja_template = generate_jinja_template(conv_mode)
+            jinja_path = os.path.join(model_path, f"{conv_mode}.jinja")
+            with open(jinja_path, "w") as f:
+                f.write(jinja_template)
         json.dump(config, open(cfg_path, "w"), indent=2)
         self.copy_remote_py_files(model_path)
-
+
     @classmethod
     def copy_remote_py_files(cls, output_dir):
         ## copy .py and REAMDE for next loading remote code
         current_file_path = os.path.abspath(__file__)
         current_folder = os.path.dirname(current_file_path)
         for file_name in os.listdir(current_folder):
-            if file_name.endswith(".py"):
+            if file_name.endswith(".py") or file_name.endswith(".jinja"):
                 full_file_name = os.path.join(current_folder, file_name)
                 if os.path.isfile(full_file_name):
                     shutil.copy(full_file_name, output_dir)
@@ -222,17 +318,15 @@
             state_dict=mm_projector_state_dict,
         )
         self.config.mm_projector_cfg = self.mm_projector.config
-
+
         ## update and save top-level config
         self.config._name_or_path = output_dir
         self.config.architectures = [self.__class__.__name__]
         self.config.save_pretrained(output_dir)
-
+
         ## copy .py and REAMDE for next loading remote code
         self.copy_remote_py_files(output_dir)
 
-
-
     @classmethod
     def from_pretrained(
         cls,
@@ -258,7 +352,7 @@
         # variables for XGrammar
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
-
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -318,11 +412,12 @@
             self.get_vision_tower().eval()
         if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
             self.get_mm_projector().eval()
-
+
+
 class VILAForCasualLM(VILAPretrainedModel):
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-
+
     def merge_features_for_dynamic_s2(self, image_features, block_sizes):
         scales = self.get_vision_tower().scales
         resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
@@ -395,7 +490,7 @@
         if getattr(self.config, "dynamic_s2", False):
             image_features = self.get_vision_tower()(images)
             image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
-
+
             image_features = [
                 self.split_chessboard(x, block_size[0], block_size[1])
                 for x, block_size in zip(image_features, new_block_sizes)
@@ -881,6 +976,7 @@
             return outputs.logits, labels
 
         return outputs
+
     @torch.inference_mode()
     def generate(
         self,
@@ -898,7 +994,7 @@
         self,
         prompt: Union[str, List],
         generation_config: Optional[GenerationConfig] = None,
-        response_format
+        response_format=None,
     ) -> str:
         # TODO(zhijianl): Support directly taking conversation as input
         conversation = [{"from": "human", "value": prompt}]
```
siglip_encoder.py (CHANGED)

```diff
@@ -20,11 +20,11 @@ import torch.nn.functional as F
 from accelerate.hooks import add_hook_to_module
 from einops import rearrange
 from s2wrapper import forward as multiscale_forward
-from transformers import AutoConfig, PreTrainedModel
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.models.siglip import SiglipVisionModel
-
+
 
 class VisionTower(nn.Module):
     def __init__(self, vision_tower, args, delay_load=False):
@@ -146,7 +146,6 @@
 
         return image_features
 
-
     @property
     def dummy_feature(self):
         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
```
tokenizer_utils.py (CHANGED)

```diff
@@ -19,9 +19,9 @@ from typing import Any, Dict, List, Optional, Sequence
 import torch
 import transformers
 
-from .conversation import default_conversation, SeparatorStyle
-from .mm_utils import tokenizer_image_token
 from .constants import IGNORE_INDEX, SENTINEL_TOKEN
+from .conversation import SeparatorStyle, default_conversation
+from .mm_utils import tokenizer_image_token
 
 # __all__ = [
 # "tokenize_conversation",
```