Spaces:
Running
on
Zero
Running
on
Zero
# -------------------------------------------------------- | |
# InternVL | |
# Copyright (c) 2024 OpenGVLab | |
# Licensed under The MIT License [see LICENSE for details] | |
# -------------------------------------------------------- | |
import io | |
import matplotlib.pyplot as plt | |
from transformers.trainer_pt_utils import LabelSmoother | |
IGNORE_TOKEN_ID = LabelSmoother.ignore_index | |
import os | |
import random | |
import re | |
from collections import Counter | |
from typing import Dict | |
import cv2 | |
import imageio | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
import transformers | |
from decord import VideoReader | |
from internvl.conversation import get_conv_template | |
from PIL import Image | |
from torch.utils.data import ConcatDataset, WeightedRandomSampler | |
from torchvision.transforms.functional import InterpolationMode | |
from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD, | |
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, | |
SIGLIP_MEAN, SIGLIP_STD) | |
try: | |
from petrel_client.client import Client | |
from petrel_client.common.config import Config | |
except ImportError as E: | |
print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.') | |
import sys | |
def calculate_ngram_repetition(text, n):
    """Return the share of n-gram windows taken up by repeated n-grams.

    ``text`` is split on whitespace and every window of ``n`` consecutive
    words is counted. The numerator is the number of *distinct* n-grams that
    occur more than once; the denominator is the total number of windows.

    Returns 0 when the text yields no windows at all.
    """
    tokens = text.split()
    window_count = len(tokens) - n + 1
    occurrences = Counter(
        tuple(tokens[start:start + n]) for start in range(window_count)
    )
    total = sum(occurrences.values())
    if total <= 0:
        return 0
    repeated = len([seen for seen in occurrences.values() if seen > 1])
    return repeated / total
def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
    """Reject a conversation whose model answers are dominated by repeated n-grams.

    Args:
        conversations: iterable of turn dicts carrying ``'from'`` and
            ``'value'`` keys. Only turns with ``'from' == 'gpt'`` are checked.
        repeat_threshold: highest tolerated value of
            ``calculate_ngram_repetition`` for a single answer.
        ngram: n-gram size passed to ``calculate_ngram_repetition``.

    Raises:
        Exception: when any checked turn exceeds ``repeat_threshold``. The
            exception type is kept as plain ``Exception`` so existing callers
            that catch it keep working; the message identifies the turn.
    """
    for turn_idx, conversation in enumerate(conversations):
        if conversation['from'] != 'gpt':
            continue
        repeat_ratio = calculate_ngram_repetition(conversation['value'], ngram)
        if repeat_ratio > repeat_threshold:
            raise Exception(
                f'{ngram}-gram repetition ratio {repeat_ratio:.3f} exceeds '
                f'threshold {repeat_threshold} in conversation turn {turn_idx}'
            )
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of ``vlen`` frames.

    Args:
        num_frames: number of indices wanted in the uniform modes.
        vlen: total number of frames available.
        sample: ``'rand'`` (one random frame per interval), ``'middle'``
            (centre of each interval), or a string containing ``'fps'`` such
            as ``'fps0.5'`` (sequential sampling at that output rate; the
            rate is parsed from the characters after the first three).
        fix_start: when given with ``sample='middle'``, offset added to the
            start of every interval instead of taking the centre.
        input_fps: frame rate of the source, used only by the fps mode.
        max_num_frames: cap on the result length in fps mode; ignored when
            not positive.

    Returns:
        A list of frame indices. In the uniform modes the list is padded with
        its last entry up to ``num_frames`` when the video is shorter.

    Raises:
        ValueError: if ``sample`` is not one of the supported modes.
    """
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [
            (interv, intervals[idx + 1] - 1)
            for idx, interv in enumerate(intervals[:-1])
        ]
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except (IndexError, ValueError):
                # An interval was empty (happens when intervals are a single
                # frame wide), so fall back to a sorted random subset. Only
                # the empty-sequence errors are caught so that interrupts and
                # unrelated failures still propagate.
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif 'fps' in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices
def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, min_num_frames=4
):
    """Read a sampled subset of frames from a GIF as RGB PIL images.

    Args:
        video_path: local path, or a path containing ``'s3://'`` that is
            fetched through ``client.get``.
        num_frames: upper bound (inclusive) of the randomly drawn number of
            frames to sample.
        sample: sampling mode forwarded to ``get_frame_indices``.
        fix_start: forwarded to ``get_frame_indices``.
        client: object exposing ``get(path)``; required for s3 paths.
        min_num_frames: lower bound of the randomly drawn frame count.

    Returns:
        List of ``PIL.Image`` frames in playback order. Indices that
        ``get_frame_indices`` repeats (padding) yield a single frame each.
    """
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)
    try:
        vlen = len(gif)
        t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        # Set lookup keeps the per-frame membership test constant time.
        wanted = {int(idx) for idx in frame_indices}
        frames = []
        for index, frame in enumerate(gif):
            if index in wanted:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
                frames.append(Image.fromarray(frame))
    finally:
        # Release the reader even if decoding fails part-way through.
        gif.close()
    return frames
def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    """Decode a sampled subset of video frames with decord.

    The video comes from ``client.get`` when the path contains ``'s3://'``
    and from the local path otherwise. When ``clip`` is a ``(start, end)``
    pair in seconds, sampling is restricted to that window. The number of
    frames is drawn uniformly from ``[min_num_frames, num_frames]``.

    Returns a list of ``PIL.Image`` frames.
    """
    if 's3://' in video_path:
        source = io.BytesIO(client.get(video_path))
    else:
        source = video_path
    video_reader = VideoReader(source, num_threads=1)

    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    if clip:
        # Restrict sampling to the requested window (seconds -> frames).
        start, end = clip
        duration = end - start
        vlen = int(duration * fps)
        start_index = int(start * fps)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        # Indices were computed relative to the clip; shift to absolute.
        frame_indices = [start_index + rel for rel in frame_indices]

    batch = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    return [Image.fromarray(batch[pos]) for pos in range(batch.shape[0])]
def extract_frame_number(filename):
    """Return the integer that precedes ``.jpg`` after the last underscore.

    ``'clip_0012.jpg'`` gives ``12``. Names that do not end in
    ``_<digits>.jpg`` give ``-1``.
    """
    # The dot is escaped so only a literal ".jpg" suffix matches; an unescaped
    # dot would accept any character there (e.g. "frame_123jpg" -> 12).
    match = re.search(r'_(\d+)\.jpg$', filename)
    return int(match.group(1)) if match else -1
def sort_frames(frame_paths):
    """Return the paths as a new list ordered by the frame number in each file name."""
    def frame_order(path):
        # Only the file name, not the directory part, carries the number.
        return extract_frame_number(os.path.basename(path))

    ordered = list(frame_paths)
    ordered.sort(key=frame_order)
    return ordered
def read_frames_folder(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    """Load a sampled subset of frames from a folder of per-frame images.

    Args:
        video_path: directory of frame images; a path containing ``'s3://'``
            is listed and read through ``client``.
        num_frames: upper bound (inclusive) of the randomly drawn number of
            frames to keep.
        sample: sampling mode forwarded to ``get_frame_indices``.
        fix_start: forwarded to ``get_frame_indices``.
        client: object exposing ``list(path)`` and ``get(path)``; required
            for s3 paths.
        clip: accepted but not used by this reader.
        min_num_frames: lower bound of the randomly drawn frame count.

    Returns:
        List of RGB ``PIL.Image`` frames ordered by frame number. All frames
        are returned when the folder holds no more than the drawn count.
    """
    is_remote = 's3://' in video_path
    if is_remote:
        image_list = sort_frames(client.list(video_path))
    else:
        image_list = sort_frames(os.listdir(video_path))

    def load_frame(name):
        fp = os.path.join(video_path, name)
        opened = Image.open(io.BytesIO(client.get(fp))) if is_remote else Image.open(fp)
        # Both sources are converted so callers always receive RGB frames.
        return opened.convert('RGB')

    vlen = len(image_list)
    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        image_list = [image_list[i] for i in frame_indices]
    # Sampling happens on names first, so only the kept frames are decoded.
    return [load_frame(name) for name in image_list]
class WeightedConcatDataset(ConcatDataset):
    """ConcatDataset that carries a weighted random sampler.

    ``weights`` is handed unchanged to ``WeightedRandomSampler``; iterating
    the dataset yields the indices that sampler draws (with replacement),
    ``total_size`` of them per pass.
    """

    def __init__(self, datasets, weights):
        super().__init__(datasets)
        self.weights = torch.DoubleTensor(weights)
        # Use the list stored by ConcatDataset: the ``datasets`` argument may
        # be a one-shot iterable that the parent constructor already consumed.
        self.total_size = sum(len(d) for d in self.datasets)
        self.sampler = WeightedRandomSampler(
            weights=self.weights, num_samples=self.total_size, replacement=True
        )

    def __iter__(self):
        return iter(self.sampler)

    def __len__(self):
        return self.total_size
def pil_loader(img_str):
    """Decode encoded image bytes into an RGB ``PIL.Image``."""
    return Image.open(io.BytesIO(img_str)).convert('RGB')
class TCSLoader(object):
    """Callable loader that fetches images or video frames through a petrel ``Client``."""

    def __init__(self, conf_path, sc_config_key='sensecore'):
        """Create the storage client from ``conf_path``, logging before and after."""
        print(f'[TCSLoader] config_path: {conf_path}')
        print('--> before Client(conf_path)')
        self.client = Client(conf_path)
        self.sc_config_key = sc_config_key
        print('--> after Client(conf_path)')

    def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
        """Load ``fn`` as a single image or as a list of video frames.

        For ``image_type='video'`` the reader is picked from the path: a
        trailing ``'/'`` means a folder of frames, ``'.gif'`` a GIF, anything
        else is decoded with decord (the only reader that receives ``clip``).
        Any other ``image_type`` returns ``None``.
        """
        if image_type == 'image':
            return pil_loader(self.client.get(fn))
        if image_type != 'video':
            return None

        # Arguments common to all three frame readers.
        shared = dict(
            num_frames=max_num_frames,
            min_num_frames=min_num_frames,
            client=self.client,
            sample=sample,
        )
        if fn.endswith('/'):
            return read_frames_folder(fn, **shared)
        if fn.endswith('.gif'):
            return read_frames_gif(fn, **shared)
        return read_frames_decord(fn, clip=clip, **shared)
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along the shorter side.

    A square input is returned unchanged. Otherwise a new image of the same
    mode, filled with ``background_color``, is created with side length equal
    to the longer dimension and the input is pasted into its centre.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if width > height:
        side = width
        offset = (0, (width - height) // 2)
    else:
        side = height
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
def simulate_jpeg_degradation(quality):
    """Build a callable that round-trips an image through JPEG at ``quality``."""
    def jpeg_degrade(img):
        buffer = io.BytesIO()
        try:
            img.convert('RGB').save(buffer, format='JPEG', quality=quality)
            buffer.seek(0)  # rewind so the encoded bytes can be read back
            # .copy() forces the pixels into memory before the buffer closes.
            return Image.open(buffer).copy()
        finally:
            buffer.close()

    return jpeg_degrade
# JPEG quality levels 75..100 (inclusive) used for augmentation; one degrade
# callable per level is built once here so transforms can reuse them.
qualities = [*range(75, 101)]
jpeg_degrade_functions = dict(
    zip(qualities, map(simulate_jpeg_degradation, qualities))
)
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
    """Compose the torchvision pipeline that turns a PIL image into a normalized tensor.

    Every pipeline converts to RGB, resizes to ``input_size`` x ``input_size``
    with bicubic interpolation, converts to a tensor and normalizes with the
    statistics selected by ``normalize_type`` ('imagenet', 'clip' or
    'siglip'). Training additionally applies a randomly chosen JPEG
    degradation before resizing. At evaluation time, unless ``pad2square`` is
    exactly ``False``, the image is first padded to a square with the mean
    colour.

    Raises:
        NotImplementedError: for an unknown ``normalize_type``.
    """
    known_stats = (
        ('imagenet', IMAGENET_MEAN, IMAGENET_STD),
        ('clip', CLIP_MEAN, CLIP_STD),
        ('siglip', SIGLIP_MEAN, SIGLIP_STD),
    )
    for name, mean, std in known_stats:
        if normalize_type == name:
            MEAN, STD = mean, std
            break
    else:
        raise NotImplementedError

    steps = [T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img)]
    if is_train:  # use data augmentation
        steps.append(
            T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities])
        )
    elif pad2square is not False:
        # Without padding is the default evaluation path; this one pads with
        # the normalization mean expressed as 0-255 integers.
        steps.append(
            T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN)))
        )
    steps.extend([
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ])
    return T.Compose(steps)
def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations and build labels that supervise only assistant replies.

    Args:
        template_name: name handed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with
            ``'from'`` ('human' or 'gpt') and ``'value'`` keys.
        tokenizer: tokenizer used for both the full prompt and per-turn
            length measurements.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>``.
        text_only: when true, ``<image>`` placeholders are left untouched.
        group_by_length: disables padding when true.
        use_packed_ds: disables padding when true.
        ds_name: dataset name, used only in the mismatch warning.
        num_image: number of ``<image>`` placeholders expanded per
            conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` (a clone of ``input_ids`` with
        non-supervised positions set to ``IGNORE_TOKEN_ID``) and
        ``attention_mask`` (``input_ids != pad_token_id``).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            # Turns must strictly alternate human / gpt.
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand the first `num_image` "<image>" placeholders, one at a time,
        # into start token + repeated context tokens + end token.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens; compared with cur_len at the end.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        # Start after the first token, which is always ignored
        # (presumably BOS -- confirm against the tokenizer).
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            # A well-formed turn splits into exactly [instruction, answer].
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Everything after the last processed turn (padding included) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID

        # Disabled debug aid.
        # NOTE(review): `logger` is not defined in the visible part of this
        # file -- confirm before enabling.
        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            logger.info(tokenizer.decode(z))
            exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Offsets did not line up with the real tokenization, so the
                # whole sample is dropped from the loss rather than trained
                # with a misaligned mask.
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations and mask non-assistant tokens for templates split on ``conv.sep``.

    Args:
        template_name: name handed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with
            ``'from'`` ('human' or 'gpt') and ``'value'`` keys.
        tokenizer: tokenizer used for both the full prompt and per-turn
            length measurements.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>``.
        text_only: when true, ``<image>`` placeholders are left untouched.
        group_by_length: disables padding when true.
        use_packed_ds: disables padding when true.
        ds_name: dataset name, used only in the mismatch warning.
        num_image: number of ``<image>`` placeholders expanded per
            conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` (a clone of ``input_ids`` with
        non-supervised positions set to ``IGNORE_TOKEN_ID``) and
        ``attention_mask`` (``input_ids != pad_token_id``).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            # Turns must strictly alternate human / gpt.
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand the first `num_image` "<image>" placeholders, one at a time,
        # into start token + repeated context tokens + end token.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens; compared with cur_len at the end.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Regroup the sep-delimited pieces into rounds: the first round takes
        # three pieces, every later round takes two.
        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            # +1 presumably accounts for the separator token removed by the
            # split above -- confirm against the tokenizer.
            turn_len = len(tokenizer(turn).input_ids) + 1

            # A well-formed round splits into exactly [instruction, answer].
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len

        # Everything after the last processed round (padding included) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Offsets did not line up with the real tokenization, so the
                # whole sample is dropped from the loss.
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations and mask non-assistant tokens, with Phi-3 specific offsets.

    Note that this function sets ``tokenizer.padding_side`` to ``'right'`` on
    the tokenizer it is given.

    Args:
        template_name: name handed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with
            ``'from'`` ('human' or 'gpt') and ``'value'`` keys.
        tokenizer: tokenizer used for both the full prompt and per-turn
            length measurements.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>``.
        text_only: when true, ``<image>`` placeholders are left untouched.
        group_by_length: disables padding when true.
        use_packed_ds: disables padding when true.
        ds_name: dataset name, used only in the mismatch warning.
        num_image: number of ``<image>`` placeholders expanded per
            conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` (a clone of ``input_ids`` with
        non-supervised positions set to ``IGNORE_TOKEN_ID``) and
        ``attention_mask`` (``input_ids != pad_token_id``).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            # Turns must strictly alternate human / gpt.
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand the first `num_image` "<image>" placeholders, one at a time,
        # into start token + repeated context tokens + end token.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens; compared with cur_len at the end.
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        # Regroup the sep-delimited pieces into rounds: the first round takes
        # three pieces, every later round takes two.
        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        # Start after the first token, which is always ignored
        # (presumably BOS -- confirm against the tokenizer).
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        # Never supervise '<|endoftext|>' tokens, wherever they occur.
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            # The first round and later rounds use different hard-coded
            # corrections to the measured token counts.
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1
            # A well-formed round splits into exactly [instruction, answer].
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len

        # Everything after the last processed round (padding included) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID

        # Disabled debug aid: decode the labels with ignored positions shown
        # as the unknown token.
        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Offsets did not line up with the real tokenization, so the
                # whole sample is dropped from the loss.
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations and mask non-assistant tokens by splitting on the role markers.

    Unlike the other preprocessors in this file, each ``sentence['value']``
    in ``sources`` is stripped in place before being added to the prompt.

    Args:
        template_name: name handed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with
            ``'from'`` ('human' or 'gpt') and ``'value'`` keys.
        tokenizer: tokenizer used for both the full prompt and per-segment
            length measurements.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>``.
        text_only: when true, ``<image>`` placeholders are left untouched.
        group_by_length: disables padding when true.
        use_packed_ds: disables padding when true.
        ds_name: dataset name, used only in the mismatch warning.
        num_image: number of ``<image>`` placeholders expanded per
            conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` (a clone of ``input_ids`` with
        non-supervised positions set to ``IGNORE_TOKEN_ID``) and
        ``attention_mask`` (``input_ids != pad_token_id``).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            # Turns must strictly alternate human / gpt.
            assert role == conv.roles[j % 2], f'{i}'
            # Writes the stripped text back into the caller's dict.
            sentence['value'] = sentence['value'].strip()
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand the first `num_image` "<image>" placeholders, one at a time,
        # into start token + repeated context tokens + end token.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # in InternLM (Puyu), pad_token_id = eos_token_id
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # <s>
        # Splitting on the assistant marker gives: [prompt up to the first
        # answer, (answer + next user prompt)..., last answer].
        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # drop the <s> added by the tokenizer
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            # part1 is the assistant answer (kept for the loss); part2 is the
            # following user turn (masked together with both role markers).
            part1, part2 = info.split(conv.roles[0])
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len
        # The final piece is the last assistant answer; it stays supervised.
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        # Everything after the last answer (padding included) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID
        # Disabled debug aid: decode the labels with ignored positions shown
        # as the unknown token.
        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Offsets did not line up with the real tokenization, so the
                # whole sample is dropped from the loss.
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_internvl2_5(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize one conversation turn by turn and supervise only assistant content.

    Each turn is rendered as ``<|im_start|>{role}\\n{text}<|im_end|>\\n`` and
    tokenized on its own, so labels are built per turn with no offset
    arithmetic over the joined prompt.

    Args:
        template_name: template whose ``system_message`` is used when the
            conversation does not start with a ``'system'`` turn.
        sources: list holding exactly one conversation (a list of dicts with
            ``'from'`` and ``'value'`` keys). The dicts are not modified.
        tokenizer: tokenizer applied to the list of rendered turns.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>``.
        text_only: when true, ``<image>`` placeholders are left untouched.
        group_by_length: disables padding when true.
        use_packed_ds: disables padding when true.
        ds_name: accepted for signature parity; not used here.
        num_image: exact number of ``<image>`` placeholders that must be
            expanded across the human turns.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``, each with
        a leading batch dimension of 1.

    Raises:
        AssertionError: if ``sources`` does not hold exactly one conversation
            or the number of expanded placeholders differs from ``num_image``.
        NotImplementedError: for a turn whose ``'from'`` is neither
            ``'human'`` nor ``'gpt'``.
    """
    assert len(sources) == 1, 'process only the first conversations'
    conversations = sources[0]

    if conversations[0]['from'] == 'system':
        system_prompt = conversations[0]['value']
        conversations = conversations[1:]  # remove system prompt
    else:
        conv = get_conv_template(template_name)
        system_prompt = conv.system_message

    if not text_only:
        new_conversations = []
        current_image_idx = 0
        for conversation in conversations:
            if conversation['from'] == 'human':
                value = conversation['value']
                # The count is taken once up front; each pass expands the
                # next remaining placeholder until `num_image` are used.
                for _ in range(value.count('<image>')):
                    if current_image_idx == num_image:
                        break
                    image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
                    value = value.replace('<image>', image_tokens, 1)
                    current_image_idx += 1
                # Work on a copy so the caller's sample keeps its "<image>"
                # placeholders and can be preprocessed again.
                conversation = {**conversation, 'value': value}
            new_conversations.append(conversation)
        conversations = new_conversations
        assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'

    batches, roles = [], []
    if system_prompt is not None:
        batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
        roles.append('system')
    for conversation in conversations:
        if conversation['from'] == 'human':
            batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
            roles.append('human')
        elif conversation['from'] == 'gpt':
            batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
            roles.append('gpt')
        else:
            raise NotImplementedError

    add_bos_token = getattr(tokenizer, 'add_bos_token', False)
    if add_bos_token:  # for InternLM series
        batches[0] = tokenizer.bos_token + batches[0]

    # Tokenize conversations
    input_ids = tokenizer(
        batches,
        return_tensors='np',
        padding=False,
        max_length=tokenizer.model_max_length,
        truncation=False,
    ).input_ids

    if add_bos_token:  # for InternLM series
        # Drop the leading token of every turn; the first turn still carries
        # the BOS that was prepended as text above.
        input_ids = [item[1:] for item in input_ids]

    final_input_ids, final_targets = [], []
    ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
    ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
    for role, input_id in zip(roles, input_ids):
        final_input_ids.append(input_id)
        if role == 'system' or role == 'human':
            final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID))  # ignore
        elif role == 'gpt':
            target = input_id.copy()
            target[:ignore_len] = IGNORE_TOKEN_ID  # ignore loss for `<|im_start|>assistant\n`
            target[-1:] = IGNORE_TOKEN_ID  # ignore loss for `\n`
            final_targets.append(target)
        else:
            raise NotImplementedError
    # Truncation happens here, on the concatenated sequence.
    input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
    targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]

    padding = False if group_by_length or use_packed_ds else True
    if padding:
        current_length = input_ids.size(0)
        padding_length = tokenizer.model_max_length - current_length
        input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
        targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)

    input_ids = input_ids.unsqueeze(0)
    targets = targets.unsqueeze(0)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid whose aspect ratio is closest to the image's.

    Candidates are scanned in the given order. A strictly closer candidate
    always wins; an exact tie replaces the current choice only when the image
    area exceeds half the pixel area the tied grid would cover. ``(1, 1)`` is
    returned when ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            # Same distance: prefer the larger grid only for images big
            # enough to fill at least half of it.
            chosen = candidate
    return chosen
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False):
    """Resize an image to the best-fitting tile grid and cut it into square tiles.

    Every (cols, rows) grid with ``min_num <= cols * rows <= max_num`` is a
    candidate; the one chosen by ``find_closest_aspect_ratio`` sets the
    resize target. Tiles of ``image_size`` x ``image_size`` are returned in
    row-major order. With ``use_thumbnail`` a whole-image thumbnail is
    appended whenever more than one tile was produced.

    Returns the list of tiles, or ``(tiles, (cols, rows))`` when
    ``return_ratio`` is true.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate the candidate grids, then order them by tile count.
    grid_options = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if i * j <= max_num and i * j >= min_num:
                    grid_options.add((i, j))
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, grid_options, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        # Crop box is (left, upper, right, lower).
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    if return_ratio:
        return processed_images, target_aspect_ratio
    return processed_images
def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Tile a stack of masks with the same grid logic as ``dynamic_preprocess``.

    ``image`` is unpacked as ``(count, height, width)``. The masks are resized
    bilinearly to the chosen grid size, thresholded with ``> 0``, and cut
    into ``image_size`` x ``image_size`` tiles in row-major order. With
    ``use_thumbnail`` a thresholded whole-mask thumbnail is appended whenever
    more than one tile was produced.

    Returns a list of boolean tensors, one per tile.
    """
    _, orig_height, orig_width = image.shape
    aspect_ratio = orig_width / orig_height

    # Enumerate the candidate grids, then order them by tile count.
    grid_options = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if i * j <= max_num and i * j >= min_num:
                    grid_options.add((i, j))
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, grid_options, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Add a singleton channel dimension so F.interpolate can resize the masks.
    tensor_images = image.unsqueeze(1)
    resized_images = F.interpolate(
        tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False
    ) > 0

    # Crop tiles the same way the PIL version does, using tensor slicing.
    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        top = row * image_size
        left = col * image_size
        tile = resized_images[..., top:top + image_size, left:left + image_size]
        # Drop the channel dimension that was added for interpolation.
        processed_images.append(tile.squeeze(1))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = F.interpolate(
            tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False
        ).squeeze(1)
        processed_images.append(thumbnail_img > 0)
    return processed_images