# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import io

import matplotlib.pyplot as plt
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import os
import random
import re
from collections import Counter
from typing import Dict

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers
from decord import VideoReader
from internvl.conversation import get_conv_template
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode

from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
                        IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
                        SIGLIP_MEAN, SIGLIP_STD)

try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
except ImportError:
    # petrel_client is only needed when reading data from ceph.
    print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
import sys

# Placeholder in the raw conversation text that is replaced by the
# IMG_START_TOKEN / IMG_CONTEXT_TOKEN * n / IMG_END_TOKEN sequence.
# NOTE: str.replace('') / str.count('') would match the empty string at every
# position, so the placeholder must be a real, non-empty marker.
IMAGE_PLACEHOLDER = '<image>'

# Matches the trailing frame number in names like ``xxx_0012.jpg``.
_FRAME_NUMBER_RE = re.compile(r'_(\d+)\.jpg$')


def calculate_ngram_repetition(text, n):
    """Measure how repetitive ``text`` is at the word n-gram level.

    Args:
        text: Text to analyse; split on whitespace into words.
        n: n-gram length in words.

    Returns:
        The number of distinct n-grams that occur more than once, divided by
        the total number of n-gram positions; ``0`` when the text has fewer
        than ``n`` words.
    """
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    ngram_counts = Counter(ngrams)
    total_ngrams = len(ngrams)
    repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
    return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0


def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
    """Reject samples whose assistant ('gpt') answers are too repetitive.

    Args:
        conversations: Iterable of dicts with ``'from'`` and ``'value'`` keys.
        repeat_threshold: Maximum allowed ratio from
            :func:`calculate_ngram_repetition`.
        ngram: n-gram length used for the check.

    Raises:
        Exception: If any 'gpt' turn exceeds ``repeat_threshold``.
    """
    for conversation in conversations:
        if conversation['from'] != 'gpt':
            continue
        model_answer = conversation['value']
        repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
        if repeat_ratio > repeat_threshold:
            raise Exception(
                f'Repetitive answer: {ngram}-gram repeat ratio '
                f'{repeat_ratio:.3f} exceeds threshold {repeat_threshold}'
            )


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices of a ``vlen``-frame video to decode.

    Modes:
        * ``'rand'`` / ``'middle'``: split ``[0, vlen)`` into
          ``min(num_frames, vlen)`` equal intervals. ``'rand'`` picks a random
          index inside each interval. Otherwise, ``fix_start`` (if given) is
          used as an offset from each interval start, or ``'middle'`` takes
          the interval midpoint. The result is padded with the last index up
          to ``num_frames``.
        * ``'fps<x>'`` (e.g. ``'fps0.5'``): sample sequentially at ``x`` frames
          per second, given the source rate ``input_fps``. The result is
          truncated to ``max_num_frames`` when that is positive.

    Returns:
        A list of frame indices.

    Raises:
        ValueError: For an unrecognised ``sample`` mode.
    """
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [(interv, intervals[idx + 1] - 1) for idx, interv in enumerate(intervals[:-1])]
        if sample == 'rand':
            try:
                # range() excludes its end, so the last frame of each interval
                # is never drawn. A one-frame interval gives an empty range, and
                # random.choice then raises IndexError, which triggers the
                # fallback below.
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif 'fps' in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f'Unsupported frame sampling mode: {sample!r}')
    return frame_indices


def _gif_frame_to_rgb(frame):
    """Convert one decoded GIF frame into an (H, W, 3) uint8 RGB array.

    GIF frames may come back as RGBA, RGB or single-channel arrays depending
    on the imageio backend. The frame is only converted when its channel
    layout requires it, instead of unconditionally applying RGBA->RGB, which
    fails on 3-channel input.
    """
    frame = np.asarray(frame)
    if frame.ndim == 2 or frame.shape[-1] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[-1] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    return frame.astype(np.uint8)


def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, min_num_frames=4
):
    """Decode a randomly-sized subset of frames from a GIF.

    A frame count is drawn uniformly from ``[min_num_frames, num_frames]`` and
    passed to :func:`get_frame_indices` to select the frames.

    Args:
        video_path: Local path, or an ``s3://`` path fetched via ``client.get``.
        num_frames: Upper bound (inclusive) for the random frame count.
        sample: Sampling mode forwarded to :func:`get_frame_indices`.
        fix_start: Offset forwarded to :func:`get_frame_indices`.
        client: Storage client, required for ``s3://`` paths.
        min_num_frames: Lower bound (inclusive) for the random frame count.

    Returns:
        List of RGB ``PIL.Image`` frames in file order. Each selected index
        appears once, even if :func:`get_frame_indices` repeated it.
    """
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)

    try:
        vlen = len(gif)
        t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        # A set gives O(1) membership instead of scanning the list per frame.
        wanted = {int(i) for i in frame_indices}
        last_wanted = max(wanted, default=-1)

        frames = []
        for index, frame in enumerate(gif):
            if index > last_wanted:
                break  # everything requested has been collected
            if index in wanted:
                frames.append(Image.fromarray(_gif_frame_to_rgb(frame)))
    finally:
        gif.close()
    return frames


def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    """Decode a randomly-sized subset of frames from a video with decord.

    Args:
        video_path: Local path, or an ``s3://`` path fetched via ``client.get``.
        num_frames: Upper bound (inclusive) for the random frame count.
        sample: Sampling mode forwarded to :func:`get_frame_indices`.
        fix_start: Offset forwarded to :func:`get_frame_indices`.
        client: Storage client, required for ``s3://`` paths.
        clip: Optional ``(start, end)`` in seconds. When given, sampling is
            restricted to that window.
        min_num_frames: Lower bound (inclusive) for the random frame count.

    Returns:
        List of ``PIL.Image`` frames.
    """
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader = VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    start_index = 0
    if clip:
        start, end = clip
        vlen = int((end - start) * fps)
        start_index = int(start * fps)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    return [Image.fromarray(frame) for frame in frames]


def extract_frame_number(filename):
    """Return the integer N from a name ending in ``_N.jpg``, or -1 if absent."""
    match = _FRAME_NUMBER_RE.search(filename)
    return int(match.group(1)) if match else -1


def sort_frames(frame_paths):
    """Sort frame paths by the numeric suffix of their basename (see extract_frame_number)."""
    return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))


def read_frames_folder(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    """Load a randomly-sized subset of frames from a folder of frame images.

    Frame names are sorted by their numeric suffix. If the folder holds more
    frames than the drawn count, :func:`get_frame_indices` picks the subset.
    Only the selected images are read and decoded.

    Args:
        video_path: Local directory, or an ``s3://`` prefix listed and read
            through ``client``.
        num_frames: Upper bound (inclusive) for the random frame count.
        sample: Sampling mode forwarded to :func:`get_frame_indices`.
        fix_start: Offset forwarded to :func:`get_frame_indices`.
        client: Storage client, required for ``s3://`` paths.
        clip: Accepted for signature compatibility; not used.
        min_num_frames: Lower bound (inclusive) for the random frame count.

    Returns:
        List of RGB ``PIL.Image`` frames.
    """
    if 's3://' in video_path:
        image_list = sort_frames(client.list(video_path))

        def load_frame(name):
            fp = os.path.join(video_path, name)
            return Image.open(io.BytesIO(client.get(fp))).convert('RGB')
    else:
        image_list = sort_frames(list(os.listdir(video_path)))

        def load_frame(name):
            return Image.open(os.path.join(video_path, name)).convert('RGB')

    vlen = len(image_list)
    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        image_list = [image_list[i] for i in frame_indices]
    return [load_frame(name) for name in image_list]


class WeightedConcatDataset(ConcatDataset):
    """ConcatDataset that also carries a weighted, with-replacement sampler.

    ``__iter__`` yields indices drawn by ``WeightedRandomSampler`` from
    ``range(len(weights))``. ``__len__`` is the total size of all datasets.
    """

    def __init__(self, datasets, weights):
        super().__init__(datasets)
        self.weights = torch.DoubleTensor(weights)
        self.total_size = sum(len(d) for d in datasets)
        self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)

    def __iter__(self):
        return iter(self.sampler)

    def __len__(self):
        return self.total_size


def pil_loader(img_str):
    """Decode encoded image bytes into an RGB ``PIL.Image``."""
    buff = io.BytesIO(img_str)
    img = Image.open(buff)
    return img.convert('RGB')


class TCSLoader(object):
    """Loads images and videos through a petrel ``Client``."""

    def __init__(self, conf_path, sc_config_key='sensecore'):
        print(f'[TCSLoader] config_path: {conf_path}')
        print('--> before Client(conf_path)')
        self.client = Client(conf_path)
        self.sc_config_key = sc_config_key
        print('--> after Client(conf_path)')

    def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
        """Load ``fn`` as an image or a list of video frames.

        For ``image_type='video'``, the reader is chosen by path: a trailing
        ``/`` means a frame folder, ``.gif`` a GIF, and anything else goes to
        decord. Returns ``None`` for any other ``image_type``.
        """
        if image_type == 'image':
            img_value_str = self.client.get(fn)
            img = pil_loader(img_value_str)
            return img

        elif image_type == 'video':
            if fn.endswith('/'):
                frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                            client=self.client, sample=sample)
            elif fn.endswith('.gif'):
                frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                         client=self.client, sample=sample)
            else:
                frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                            client=self.client, sample=sample, clip=clip)
            return frames


def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square filled with ``background_color``, keeping it centred."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def simulate_jpeg_degradation(quality):
    """Return a function that round-trips an image through JPEG at ``quality``."""
    def jpeg_degrade(img):
        with io.BytesIO() as output:
            img.convert('RGB').save(output, format='JPEG', quality=quality)
            output.seek(0)  # Move the reading cursor to the start of the stream
            img_jpeg = Image.open(output).copy()  # Use .copy() to make sure the image is loaded in memory
        return img_jpeg
    return jpeg_degrade


# Define the JPEG compression quality range, pre-create all JPEG compression functions
qualities = list(range(75, 101))
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}


def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
    """Build the image -> normalised tensor transform.

    Training applies a random JPEG re-compression (quality 75-100) before
    resizing. At eval time the image is either resized directly or first
    padded to a square with the mean colour when ``pad2square`` is True.

    Raises:
        NotImplementedError: For an unknown ``normalize_type``.
    """
    if normalize_type == 'imagenet':
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif normalize_type == 'clip':
        MEAN, STD = CLIP_MEAN, CLIP_STD
    elif normalize_type == 'siglip':
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise NotImplementedError
    if is_train:  # use data augmentation
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
    else:
        if pad2square is False:  # now we use this transform function by default
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD)
            ])
        else:
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD)
            ])

    return transform


def _apply_prompt_template(template_name, sources, strip_values=False):
    """Render every conversation in ``sources`` with the named chat template.

    A leading non-human turn is dropped, and roles must alternate
    human/gpt. When ``strip_values`` is True, each sentence's ``'value'`` is
    stripped in place before rendering.

    Returns:
        ``(conv, prompts)``: the template object (used afterwards for its
        ``sep``/``sep2``/``roles``) and the list of rendered prompt strings.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            if strip_values:
                sentence['value'] = sentence['value'].strip()
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())
    return conv, conversations


def _expand_image_placeholders(conversations, num_image_token_list, num_image):
    """Replace the first ``num_image`` IMAGE_PLACEHOLDERs in each prompt with image tokens.

    Image ``i`` becomes ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token_list[i] + IMG_END_TOKEN``.
    """
    new_conversations = []
    for conversation in conversations:
        for i in range(num_image):
            image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
            conversation = conversation.replace(IMAGE_PLACEHOLDER, image_tokens, 1)
        new_conversations.append(conversation)
    return new_conversations


def _tokenize_conversations(tokenizer, conversations, group_by_length, use_packed_ds):
    """Tokenise prompts into a ``(B, L)`` id tensor, truncated to ``model_max_length``.

    Prompts are padded to ``model_max_length`` unless length grouping or
    packed datasets are in use.
    """
    return tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids


def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenise conversations for Vicuna/Llama-style (``sep2``-separated) templates.

    Only assistant responses contribute to the loss; everything else in
    ``labels`` is IGNORE_TOKEN_ID. If the computed length disagrees with the
    non-pad token count, the whole sample is masked and a warning is printed.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv, conversations = _apply_prompt_template(template_name, sources)

    if not text_only:
        conversations = _expand_image_placeholders(conversations, num_image_token_list, num_image)

    # Tokenize conversations
    input_ids = _tokenize_conversations(tokenizer, conversations, group_by_length, use_packed_ds)
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenise conversations for ChatML/MPT-style templates (``<|im_end|>`` separators).

    Only assistant responses contribute to the loss; a length mismatch masks
    the whole sample and prints a warning.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv, conversations = _apply_prompt_template(template_name, sources)

    if not text_only:
        conversations = _expand_image_placeholders(conversations, num_image_token_list, num_image)

    # Tokenize conversations
    input_ids = _tokenize_conversations(tokenizer, conversations, group_by_length, use_packed_ds)
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids) + 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenise conversations for Phi-3 templates (``<|end|>`` separators).

    Sets ``tokenizer.padding_side = 'right'`` as a side effect. ``<|endoftext|>``
    tokens and user turns are excluded from the loss; a length mismatch masks
    the whole sample and prints a warning.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv, conversations = _apply_prompt_template(template_name, sources)

    if not text_only:
        conversations = _expand_image_placeholders(conversations, num_image_token_list, num_image)

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = _tokenize_conversations(tokenizer, conversations, group_by_length, use_packed_ds)
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenise conversations for InternLM templates.

    Sentence values are stripped in place before rendering. Only assistant
    responses contribute to the loss; a length mismatch masks the whole
    sample and prints a warning.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv, conversations = _apply_prompt_template(template_name, sources, strip_values=True)

    if not text_only:
        conversations = _expand_image_placeholders(conversations, num_image_token_list, num_image)

    # Tokenize conversations
    input_ids = _tokenize_conversations(tokenizer, conversations, group_by_length, use_packed_ds)
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        # In InternLM, pad_token_id == eos_token_id
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # presumably the leading BOS token
        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n
        info = parts[0] + conv.roles[1]
        # "- 1" drops the token the tokenizer prepends (presumably BOS).
        temp_len = len(tokenizer(info).input_ids) - 1
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            # part1 is the assistant answer: keep it in the loss.
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            # The following user turn up to the next assistant header is masked.
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_internvl2_5(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenise a single conversation turn-by-turn with ChatML markers (InternVL2.5).

    Each turn is tokenised separately and concatenated. System and user turns
    are fully masked. Assistant turns keep their content but mask the
    ``<|im_start|>assistant\\n`` header and the final token. Image
    placeholders in human turns are replaced, in order, until ``num_image``
    images have been used. Note that ``conversation['value']`` is modified in
    place.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``, each of
        shape ``(1, L)``.
    """
    assert len(sources) == 1, 'process only the first conversations'
    conversations = sources[0]

    if conversations[0]['from'] == 'system':
        system_prompt = conversations[0]['value']
        conversations = conversations[1:]  # remove system prompt
    else:
        conv = get_conv_template(template_name)
        system_prompt = conv.system_message

    if not text_only:
        new_conversations = []
        current_image_idx = 0
        for conversation in conversations:
            if conversation['from'] == 'human':
                image_cnt = conversation['value'].count(IMAGE_PLACEHOLDER)
                for _ in range(image_cnt):
                    if current_image_idx == num_image:
                        break
                    image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
                    conversation['value'] = conversation['value'].replace(IMAGE_PLACEHOLDER, image_tokens, 1)
                    current_image_idx += 1
            new_conversations.append(conversation)
        conversations = new_conversations
        assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'

    batches, roles = [], []
    if system_prompt is not None:
        batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
        roles.append('system')
    for conversation in conversations:
        if conversation['from'] == 'human':
            batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
            roles.append('human')
        elif conversation['from'] == 'gpt':
            batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
            roles.append('gpt')
        else:
            raise NotImplementedError

    add_bos_token = getattr(tokenizer, 'add_bos_token', False)
    if add_bos_token:  # for InternLM series
        batches[0] = tokenizer.bos_token + batches[0]

    # Tokenize conversations
    input_ids = tokenizer(
        batches,
        return_tensors='np',
        padding=False,
        max_length=tokenizer.model_max_length,
        truncation=False,
    ).input_ids

    if add_bos_token:  # for InternLM series
        input_ids = [item[1:] for item in input_ids]

    final_input_ids, final_targets = [], []
    ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
    ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
    for role, input_id in zip(roles, input_ids):
        final_input_ids.append(input_id)
        if role == 'system' or role == 'human':
            final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID))  # ignore
        elif role == 'gpt':
            target = input_id.copy()
            target[:ignore_len] = IGNORE_TOKEN_ID  # ignore loss for `<|im_start|>assistant\n`
            target[-1:] = IGNORE_TOKEN_ID  # ignore loss for `\n`
            final_targets.append(target)
        else:
            raise NotImplementedError
    input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
    targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]

    padding = False if group_by_length or use_packed_ds else True
    if padding:
        current_length = input_ids.size(0)
        padding_length = tokenizer.model_max_length - current_length
        input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
        targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)

    input_ids = input_ids.unsqueeze(0)
    targets = targets.unsqueeze(0)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose cols/rows ratio is closest to ``aspect_ratio``.

    On a tie, a later candidate replaces the current best when the image
    area exceeds half of that candidate's pixel area
    (``image_size**2 * cols * rows``).
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False):
    """Tile a PIL image into an aspect-ratio-matched grid of ``image_size`` squares.

    The grid has between ``min_num`` and ``max_num`` tiles. Tiles are emitted
    row-major. When ``use_thumbnail`` is True and there is more than one
    tile, a resized copy of the whole image is appended.

    Returns:
        List of tiles, plus the chosen ``(cols, rows)`` when ``return_ratio``.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = []
    for i in range(blocks):
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size
        )
        # split the image
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    if return_ratio:
        return processed_images, target_aspect_ratio
    return processed_images


def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Tile a stack of masks the same way :func:`dynamic_preprocess` tiles an image.

    ``image`` is unpacked as ``(N, H, W)``. It presumably needs a floating
    dtype, since bilinear ``F.interpolate`` requires one (TODO confirm with
    callers). Masks are resized bilinearly and binarised with ``> 0``.

    Returns:
        List of ``(N, image_size, image_size)`` boolean tensors, row-major,
        plus a thumbnail when ``use_thumbnail`` is True and there is more
        than one tile.
    """
    _, orig_height, orig_width = image.shape
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Add a singleton channel dimension: (N, H, W) -> (N, 1, H, W).
    tensor_images = image.unsqueeze(1)
    resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear',
                                   align_corners=False)
    resized_images = resized_images > 0

    # Crop tiles with tensor slicing, mirroring the PIL crop above.
    cols = target_width // image_size
    processed_images = []
    for i in range(blocks):
        top = (i // cols) * image_size
        left = (i % cols) * image_size
        bottom = top + image_size
        right = left + image_size
        # `...` keeps the leading (N, 1) dimensions.
        processed_images.append(resized_images[..., top:bottom, left:right])

    # Drop the singleton channel dimension again.
    processed_images = [img.squeeze(1) for img in processed_images]
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear',
                                      align_corners=False).squeeze(1)
        thumbnail_img = thumbnail_img > 0
        processed_images.append(thumbnail_img)
    return processed_images