import os
import os.path as osp
from collections import defaultdict
from typing import List, Union

from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
from .media import Image, Video, extract_media
from .mm_utils import process_image, process_images
from .tokenizer_utils import tokenize_conversation


class VILAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class VILAProcessor(ProcessorMixin):
    # attributes = ["image_processor", "tokenizer"]
    attributes = []
    # valid_kwargs = ["chat_template"]
    valid_kwargs = []
    # image_processor_class = "VILAImageProcessor"
    # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
        # self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        # self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        self.image_token = MEDIA_TOKENS["image"]
        self.video_token = MEDIA_TOKENS["video"]
        self.config = config
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        if not os.path.isdir(pretrained_model_name_or_path):
            print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
            from huggingface_hub import snapshot_download

            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        image_processor = AutoImageProcessor.from_pretrained(
            osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
        )
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)

    def __repr__(self):
        return (
            f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
        )
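
    # A minimal usage sketch for the loader above (not part of the original API surface;
    # assumes a checkpoint directory or hub repo laid out with "vision_tower" and "llm"
    # subfolders, as from_pretrained expects, and `messages` is a placeholder conversation):
    #
    #   processor = VILAProcessor.from_pretrained("NVILA-Lite-2B-hf-preview")
    #   inputs = processor.apply_chat_template(conversation=messages)
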
    def __call__(
        self,
        conversation,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        # TODO: should be merged with llava_arch.py/generate_content()
        # TODO: extraction and preprocessing should be done together, as the preprocessing of images
        #       and videos can differ, e.g. when dynamic resolution is used.
        media = extract_media(conversation, self.config)

        # Process media
        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    self.config.image_processor = self.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        # Dynamic resolution splits the image into tiles, so the image token in the
                        # prompt is repeated once per tile.
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        if type(self.config.s2_scales) is str:
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.image_processor, self.config).half()
                media[name] = [image for image in images]
            elif name == "video":
                media[name] = [
                    process_images(images, self.image_processor, self.config).half()
                    for images in media[name]
                ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
        # Set up the generation config
        # print(input_ids.shape); print(media); input()
        return BatchFeature(data={"input_ids": input_ids, **media})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
        # Convert an OpenAI-style conversation (role/content) into the VILA-style
        # conversation (from/value) consumed by __call__.
        vila_conv = []
        for chat in conversation:
            vila_chat = {"from": "", "value": []}
            if chat["role"] == "user":
                # user turns may contain both image and text content
                vila_chat["from"] = "human"
                for content in chat["content"]:
                    if content["type"] == "image":
                        vila_chat["value"].append(Image(content["path"]))
                    elif content["type"] == "text":
                        vila_chat["value"].append(content["text"])
                    else:
                        raise ValueError(f"Unsupported content type: {content['type']}")
            elif chat["role"] == "assistant":
                vila_chat["from"] = "gpt"
                for content in chat["content"]:
                    assert content["type"] == "text", f"Unsupported content type: {content['type']}"
                    vila_chat["value"].append(content["text"])
            vila_conv.append(vila_chat)

        return self(vila_conv)


if __name__ == "__main__":
    import PIL.Image

    # gpt style: user, assistant
    # vila style: human, gpt
    gpt_conv = [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "demo_images/demo_img_1.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    llavaconv = [
        {
            "from": "human",
            "value": [
                PIL.Image.open("demo_images/demo_img_1.png"),
                "Describe this image.",
            ],
        }
    ]

    model_path = "NVILA-Lite-2B-hf-preview"

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")

    # model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
    # print(model)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
    # res = model.generate_content(["how are you today?"])
    # print(model.config)
    # print(model.tokenizer)
    # print(res)
    # exit(0)

    processor = VILAProcessor(
        config=model.config,
        image_processor=model.vision_tower.image_processor,
        tokenizer=model.tokenizer,
    )

    # TODO: add padding, return_tensors,
    inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("vila conv pass")

    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("gpt conv pass")

    output_ids = model.generate(
        input_ids=inputs.input_ids,
        media={
            "image": inputs.image,
        },
        media_config={"image": {}},
        generation_config=model.generation_config,
        max_new_tokens=100,
    )
    print(output_ids)