Ligeng-Zhu committed (verified)
Commit e352a62 · 1 Parent(s): 6bce680

Upload files with `vila-upload`.

Upload media_encoder.py
Upload media.py
Upload modeling_vila.py
Upload configuration_vila.py
Upload builder.py
Upload mm_utils.py
Upload tokenizer_utils.py
Upload siglip_encoder.py

builder.py CHANGED

@@ -22,9 +22,9 @@ from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
+import transformers
 from huggingface_hub import file_exists, repo_exists
 from huggingface_hub.utils import HFValidationError
-import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -33,8 +33,9 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizer,
 )
+
 # from .conversation import *
-from .conversation import default_conversation, SeparatorStyle
+from .conversation import SeparatorStyle, default_conversation
 
 SENTINEL_TOKEN = "<vila/sentinel>"
 MEDIA_TOKENS = {
@@ -51,9 +52,11 @@ DUMMY_CONVERSATION = [
     {"from": "gpt", "value": "answer"},
 ] * 10
 
+
 def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
     return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
-
+
+
 def has_tokenizer(repo_id_or_path: str) -> bool:
     # Check if the tokenizer is in a local directory
     if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
@@ -65,12 +68,14 @@ def has_tokenizer(repo_id_or_path: str) -> bool:
     except HFValidationError:
         return False
 
+
 def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
     if not hasattr(tokenizer, "sentinel_token"):
         tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
         tokenizer.sentinel_token = SENTINEL_TOKEN
         tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
 
+
 def tokenize_conversation_legacy(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -103,6 +108,7 @@ def tokenize_conversation_legacy(
 
     return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
 
+
 def tokenize_conversation(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -148,6 +154,7 @@ def tokenize_conversation(
     )
     return tokenizer_image_token(text, tokenizer, return_tensors="pt")
 
+
 def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
     _maybe_add_sentinel_token(tokenizer)
     template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
@@ -159,6 +166,7 @@ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
         stop_tokens.add(stop_token)
     return list(stop_tokens)
 
+
 def context_length_extension(config):
     orig_ctx_len = getattr(config, "max_position_embeddings", None)
     model_max_length = getattr(config, "model_max_length", None)
@@ -186,7 +194,7 @@ def build_llm_and_tokenizer(
 
     # Quantization related
     quantization_restore_from_checkpoint = False
-
+
     if quantization_restore_from_checkpoint:
         fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
 
@@ -215,6 +223,8 @@
     if getattr(config, "chat_template", None) is not None:
         print(f"Using chat template: {config.chat_template}")
         fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
+        if not os.path.exists(fpath):
+            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
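
The functional change in builder.py is the two-step chat-template lookup: the loader first checks the `chat_templates/` directory shipped next to `builder.py`, and only then falls back to a `.jinja` file resolved from the checkpoint path. A minimal sketch of that resolution order; the helper name `resolve_chat_template_path` is hypothetical, not part of this commit:

import os

def resolve_chat_template_path(chat_template: str, model_name_or_path: str) -> str:
    # Hypothetical helper mirroring the lookup order added in this commit.
    # 1) Template bundled with the remote code, e.g. chat_templates/vicuna_v1.jinja.
    fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{chat_template}.jinja")
    if not os.path.exists(fpath):
        # 2) Fallback: look relative to the checkpoint path. Note the diff uses the
        #    *parent* of model_name_or_path, not the checkpoint directory itself.
        fpath = os.path.join(os.path.dirname(model_name_or_path), f"{chat_template}.jinja")
    return fpath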
configuration_vila.py CHANGED

@@ -1,15 +1,24 @@
+import json
 import math
+import os
+import os.path as osp
+from copy import deepcopy
+from threading import Thread
 from typing import List, Optional
-import json
+
 import torch
 import torchvision
-import os, os.path as osp
-
-from threading import Thread
-from copy import deepcopy
 from PIL import Image
-from transformers import Qwen2Config, PretrainedConfig, PreTrainedModel
-from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
+from transformers import (
+    AutoProcessor,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2PreTrainedModel,
+    TextIteratorStreamer,
+)
+
 
 class VILAConfig(PretrainedConfig):
     model_type = "vila"
@@ -82,4 +91,3 @@ class VILAConfig(PretrainedConfig):
         self.video_encoder = video_encoder
 
         super().__init__(**kwargs)
-
media.py CHANGED

@@ -20,13 +20,16 @@ MEDIA_TOKENS = {
     "video": "<vila/video>",
 }
 
+
 class Media:
     pass
 
+
 class File(Media):
     def __init__(self, path: str) -> None:
         self.path = path
 
+
 class Image(File):
     pass
 
@@ -34,6 +37,7 @@ class Image(File):
 class Video(File):
     pass
 
+
 def make_list(obj: Any) -> List:
     return obj if isinstance(obj, list) else [obj]
 
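The blank-line changes above are purely cosmetic (PEP 8 spacing between top-level definitions); behavior is unchanged. For orientation, a small hedged usage sketch of the wrappers and `make_list` shown in the hunk (the import path is a stand-in for however this module is packaged):

from media import Image, Video, make_list

img = Image("photo.jpg")   # File subclass: only records the path
vid = Video("clip.mp4")

assert make_list(img) == [img]               # a single item is wrapped in a list
assert make_list([img, vid]) == [img, vid]   # an existing list passes through unchanged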
media_encoder.py CHANGED

@@ -1,8 +1,9 @@
-import torch
-from torch import nn
 from functools import partial
 from typing import Any, Dict, List, Optional
 
+import torch
+from torch import nn
+
 
 class BaseEncoder(nn.Module):
     def __init__(self, parent: nn.Module) -> None:
mm_utils.py CHANGED

@@ -26,7 +26,7 @@ import torch
 from PIL import Image
 from transformers import StoppingCriteria
 
-from llava.constants import DEFAULT_IMAGE_TOKEN
+from .constants import DEFAULT_IMAGE_TOKEN
 
 
 def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
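
This is the one substantive fix in mm_utils.py: the absolute `llava.constants` import fails on machines without the `llava` package installed, while the relative `.constants` import resolves within the set of `.py` files shipped alongside the checkpoint. That matters because these modules are executed through transformers' remote-code path, roughly as sketched below (the repo id is a placeholder, not named in this commit):

from transformers import AutoModel

# Placeholder repo id; any checkpoint carrying these .py files loads the same way.
model = AutoModel.from_pretrained(
    "org/some-vila-checkpoint",
    trust_remote_code=True,  # executes modeling_vila.py (and its relative imports) from the repo
    device_map="auto",
)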
modeling_vila.py CHANGED

@@ -1,4 +1,3 @@
-import shutil
 import copy
 import json
 import logging
@@ -6,6 +5,7 @@ import math
 import os
 import os.path
 import os.path as osp
+import shutil
 import warnings
 from abc import ABC
 from collections import OrderedDict, defaultdict, deque
@@ -15,13 +15,12 @@ from threading import Thread
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 from einops import rearrange
 from PIL import Image
-
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -34,28 +33,30 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
     Qwen2PreTrainedModel,
-    TextIteratorStreamer
+    TextIteratorStreamer,
 )
-from transformers.modeling_utils import ContextManagers, no_init_weights
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import ContextManagers, no_init_weights
 
 from .base_projector import MultimodalProjector, MultimodalProjectorConfig
 from .builder import build_llm_and_tokenizer
 from .configuration_vila import VILAConfig
-from .media_encoder import BasicImageEncoder, BasicVideoEncoder
-from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
-from .utils import get_model_config
+from .constants import *
+from .conversation import SeparatorStyle, default_conversation
 from .media import extract_media
+from .media_encoder import BasicImageEncoder, BasicVideoEncoder
 from .mm_utils import process_image, process_images
+from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
 from .tokenizer_utils import tokenize_conversation
-from .constants import *
-from .conversation import default_conversation, SeparatorStyle
+from .utils import get_model_config
+
 
 # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
 # quick hack for remote code
 def get_pg_manager():
     return None
 
+
 def get_model_weights_dtype(model: nn.Module):
     pass
 
@@ -72,7 +73,77 @@ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> Pre
     mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
     mm_projector = MultimodalProjector(mm_projector_cfg, config)
     return mm_projector
-
+
+
+def check_dot_in_model_path(model_path: str):
+    """Check if the model path contains dot, which will affect the remote code loading."""
+    if osp.isdir(model_path):  # local model
+        if "." in osp.abspath(model_path):
+            return True
+    else:  # remote model
+        if "." in model_path:
+            return True
+    return False
+
+
+def get_vila_version(model_path: str) -> str:
+    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
+    for version in VERSIONS:
+        if version in model_path.lower():
+            return version
+    return None
+
+
+def generate_jinja_template(conv_mode: str) -> str:
+    if conv_mode == "vicuna_v1":
+        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." %}
+{% set roles = ["USER", "ASSISTANT"] %}
+{% set sep = " " %}
+{% set sep2 = "</s>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+{% if message['role'] == roles[0] %}
+{{ roles[0] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+{% else %}
+{{ roles[1] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+{% endif %}
+{% endfor %}"""
+    elif conv_mode == "llama_3":
+        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." %}
+{% set roles = ["<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"] %}
+{% set sep = "<|eot_id|>" %}
+{% set sep2 = "<|end_of_text|>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{ roles[0] }}{{ message['content'] }}{{ sep }}
+{% else %}
+{{ roles[1] }}{{ message['content'] }}{{ sep }}
+{% endif %}
+{% endfor %}
+
+{{ sep2 }}"""
+    elif conv_mode == "hermes_2":
+        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
+{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
+{% set sep = "<|im_end|>" %}
+
+{{ system_prompt }}{{ sep }}
+
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{ roles[0] }}{{ message['content'] }}{{ sep }}
+{% else %}
+{{ roles[1] }}{{ message['content'] }}{{ sep }}
+{% endif %}
+{% endfor %}"""
+    else:
+        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
+
 
 def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
     ## skip vision tower instantiation
@@ -110,7 +181,7 @@ class VILAPretrainedModel(PreTrainedModel):
     main_input_name = "input_embeds"
     supports_gradient_checkpointing = True
     _supports_flash_attn_2 = True
-
+
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config)
         self.config = config
@@ -119,22 +190,19 @@ class VILAPretrainedModel(PreTrainedModel):
             llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
         else:
             raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
-
+
         # loading on cpu by default
        device_map = kwargs.get("device_map", "cpu")
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        if "auto" in device_map or "cuda" in device_map:
            self.mm_projector = self.mm_projector.cuda()
-            self.vision_tower = self.vision_tower.cuda()
+            self.vision_tower = self.vision_tower.cuda()
        # set device_map auto can autoamtically shard llm to different devices
        self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
-
-        self.encoders = {
-            "image": BasicImageEncoder(self),
-            "video": BasicVideoEncoder(self)
-        }
-
+
+        self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}
+
        self.post_config()
        self.is_loaded = True
 
@@ -143,37 +211,65 @@ class VILAPretrainedModel(PreTrainedModel):
         ), "At least one of the components must be instantiated."
 
     @classmethod
-    def convert_vila_dev_ckpt_to_remote(self, model_path: str, output_dir:str = None, *model_args, **kwargs):
+    def convert_vila_dev_ckpt_to_remote(
+        self,
+        model_path: str,
+        output_dir: str = None,
+        vila_version: str | None = None,
+        conv_mode: str | None = None,
+        *model_args,
+        **kwargs,
+    ):
         # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
         from huggingface_hub import HfApi, snapshot_download
 
         if os.path.isdir(model_path):
             model_path = model_path
         api = HfApi()
+
+        if check_dot_in_model_path(model_path) and output_dir is None:
+            raise ValueError(
+                f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
+            )
+        if output_dir is not None and "." in output_dir:
+            raise ValueError(
+                f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
+            )
+        if vila_version is None:
+            vila_version = get_vila_version(model_path)
+
         if api.repo_exists(model_path):
             model_path = snapshot_download(model_path, local_dir=output_dir)
             print("downloading HF model to", model_path)
-
+
         cfg_path = os.path.join(model_path, "config.json")
         config = json.load(open(cfg_path))
-        config["version"] = "2.0" # nvila tag
+        config["version"] = "2.0"  # nvila tag
         config["architectures"] = ["VILAForCasualLM"]
         config["auto_map"] = {
             "AutoConfig": "modeling_vila.VILAConfig",
             "AutoModel": "modeling_vila.VILAForCasualLM",
-            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
+            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
         }
         config["model_type"] = "vila"
+        if vila_version in ["vila1.5", "vila-m3"]:
+            if conv_mode is None:
+                raise ValueError(f"Please specify the conversation mode for {model_path}.")
+            config["chat_template"] = conv_mode
+            jinja_template = generate_jinja_template(conv_mode)
+            jinja_path = os.path.join(model_path, f"{conv_mode}.jinja")
+            with open(jinja_path, "w") as f:
+                f.write(jinja_template)
         json.dump(config, open(cfg_path, "w"), indent=2)
         self.copy_remote_py_files(model_path)
-
+
     @classmethod
     def copy_remote_py_files(cls, output_dir):
         ## copy .py and REAMDE for next loading remote code
         current_file_path = os.path.abspath(__file__)
         current_folder = os.path.dirname(current_file_path)
         for file_name in os.listdir(current_folder):
-            if file_name.endswith(".py"):
+            if file_name.endswith(".py") or file_name.endswith(".jinja"):
                 full_file_name = os.path.join(current_folder, file_name)
                 if os.path.isfile(full_file_name):
                     shutil.copy(full_file_name, output_dir)
@@ -222,17 +318,15 @@ class VILAPretrainedModel(PreTrainedModel):
             state_dict=mm_projector_state_dict,
         )
         self.config.mm_projector_cfg = self.mm_projector.config
-
+
         ## update and save top-level config
         self.config._name_or_path = output_dir
         self.config.architectures = [self.__class__.__name__]
         self.config.save_pretrained(output_dir)
-
+
         ## copy .py and REAMDE for next loading remote code
         self.copy_remote_py_files(output_dir)
 
-
-
     @classmethod
     def from_pretrained(
         cls,
@@ -258,7 +352,7 @@ class VILAPretrainedModel(PreTrainedModel):
         # variables for XGrammar
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
-
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -318,11 +412,12 @@ class VILAPretrainedModel(PreTrainedModel):
             self.get_vision_tower().eval()
         if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
             self.get_mm_projector().eval()
-
+
+
 class VILAForCasualLM(VILAPretrainedModel):
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-
+
     def merge_features_for_dynamic_s2(self, image_features, block_sizes):
         scales = self.get_vision_tower().scales
         resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
@@ -395,7 +490,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         if getattr(self.config, "dynamic_s2", False):
             image_features = self.get_vision_tower()(images)
             image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
-
+
             image_features = [
                 self.split_chessboard(x, block_size[0], block_size[1])
                 for x, block_size in zip(image_features, new_block_sizes)
@@ -881,6 +976,7 @@ class VILAForCasualLM(VILAPretrainedModel):
             return outputs.logits, labels
 
         return outputs
+
     @torch.inference_mode()
     def generate(
         self,
@@ -898,7 +994,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         self,
         prompt: Union[str, List],
         generation_config: Optional[GenerationConfig] = None,
-        response_format = None,
+        response_format=None,
     ) -> str:
         # TODO(zhijianl): Support directly taking conversation as input
         conversation = [{"from": "human", "value": prompt}]
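
Beyond formatting, the new code in this file wires up dev-checkpoint conversion: `check_dot_in_model_path` guards against paths that break dynamic module loading, `get_vila_version` sniffs the model family from the path, `generate_jinja_template` materializes a chat template for pre-NVILA checkpoints, and `convert_vila_dev_ckpt_to_remote` rewrites `config.json` while `copy_remote_py_files` now ships `.jinja` files alongside the `.py` files. A hedged sketch of a conversion call, with placeholder paths:

from modeling_vila import VILAForCasualLM  # "Casual" is the spelling used in the code

# Convert a local VILA-1.5 dev checkpoint into a self-contained remote-code checkpoint.
# vila1.5 / vila-m3 checkpoints require conv_mode so a .jinja template can be generated.
VILAForCasualLM.convert_vila_dev_ckpt_to_remote(
    model_path="/ckpts/vila-15-13b-dev",     # placeholder; a dotted path needs an explicit dot-free output_dir
    output_dir="/ckpts/vila-15-13b-remote",  # placeholder; dots here always raise ValueError
    vila_version="vila1.5",                  # otherwise inferred from the path via get_vila_version
    conv_mode="vicuna_v1",                   # vicuna_v1 / llama_3 / hermes_2 are implemented
)

One thing worth noting from the diff: the generated `.jinja` is written into `model_path` itself, while `build_llm_and_tokenizer`'s fallback reads from `os.path.dirname(model_name_or_path)`, so the layout of converted checkpoints matters for the fallback to fire.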
siglip_encoder.py CHANGED

@@ -20,11 +20,11 @@ import torch.nn.functional as F
 from accelerate.hooks import add_hook_to_module
 from einops import rearrange
 from s2wrapper import forward as multiscale_forward
-from transformers import AutoConfig, PreTrainedModel
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.models.siglip import SiglipVisionModel
-from transformers import PretrainedConfig, SiglipImageProcessor
+
 
 class VisionTower(nn.Module):
     def __init__(self, vision_tower, args, delay_load=False):
@@ -146,7 +146,6 @@ class VisionTower(nn.Module):
 
         return image_features
 
-
     @property
     def dummy_feature(self):
         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
tokenizer_utils.py CHANGED

@@ -19,9 +19,9 @@ from typing import Any, Dict, List, Optional, Sequence
 import torch
 import transformers
 
-from .conversation import default_conversation, SeparatorStyle
-from .mm_utils import tokenizer_image_token
 from .constants import IGNORE_INDEX, SENTINEL_TOKEN
+from .conversation import SeparatorStyle, default_conversation
+from .mm_utils import tokenizer_image_token
 
 # __all__ = [
 #     "tokenize_conversation",
  # "tokenize_conversation",