# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import io
import os
import random
import re
import sys
from collections import Counter
from typing import Dict
import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from decord import VideoReader
from internvl.conversation import get_conv_template
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
SIGLIP_MEAN, SIGLIP_STD)
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
except ImportError:
    print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
def calculate_ngram_repetition(text, n):
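    """Return the number of distinct word n-grams that occur more than once, divided by the total number of n-grams in `text` (0 if the text has fewer than n words)."""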
words = text.split()
ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
ngram_counts = Counter(ngrams)
total_ngrams = len(ngrams)
repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0
def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
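    """Raise an exception if any 'gpt' turn has an n-gram repetition ratio above `repeat_threshold`, letting callers skip overly repetitive samples."""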
for conversation in conversations:
if conversation['from'] == 'gpt':
model_answer = conversation['value']
repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
if repeat_ratio > repeat_threshold:
                raise Exception(f'repetitive response detected: {ngram}-gram repetition ratio {repeat_ratio:.2f} exceeds {repeat_threshold}')
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
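    """Select `num_frames` frame indices from a video of `vlen` frames: 'rand' draws one index per
    uniform interval, 'middle' uses interval midpoints (or `start + fix_start` when `fix_start` is
    given), and 'fpsX' (e.g. 'fps0.5') samples at X fps based on `input_fps`, capped at `max_num_frames`."""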
if sample in ['rand', 'middle']: # uniform sampling
acc_samples = min(num_frames, vlen)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
ranges = []
for idx, interv in enumerate(intervals[:-1]):
ranges.append((interv, intervals[idx + 1] - 1))
if sample == 'rand':
try:
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:  # an interval can be empty when acc_samples == vlen
frame_indices = np.random.permutation(vlen)[:acc_samples]
frame_indices.sort()
frame_indices = list(frame_indices)
elif fix_start is not None:
frame_indices = [x[0] + fix_start for x in ranges]
elif sample == 'middle':
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
else:
raise NotImplementedError
if len(frame_indices) < num_frames: # padded with last frame
padded_frame_indices = [frame_indices[-1]] * num_frames
padded_frame_indices[:len(frame_indices)] = frame_indices
frame_indices = padded_frame_indices
elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
output_fps = float(sample[3:])
duration = float(vlen) / input_fps
delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
frame_indices = np.around(frame_seconds * input_fps).astype(int)
frame_indices = [e for e in frame_indices if e < vlen]
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
frame_indices = frame_indices[:max_num_frames]
# frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
else:
raise ValueError
return frame_indices
def read_frames_gif(
video_path, num_frames, sample='rand', fix_start=None,
client=None, min_num_frames=4
):
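    """Read a GIF from a local path or 's3://' URI (via `client`) and return a randomly chosen number of frames in [min_num_frames, num_frames] as RGB PIL images."""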
if 's3://' in video_path:
video_bytes = client.get(video_path)
gif = imageio.get_reader(io.BytesIO(video_bytes))
else:
gif = imageio.get_reader(video_path)
vlen = len(gif)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start
)
frames = []
for index, frame in enumerate(gif):
if index in frame_indices:
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
frame = Image.fromarray(frame)
frames.append(frame)
return frames
def read_frames_decord(
video_path, num_frames, sample='rand', fix_start=None,
client=None, clip=None, min_num_frames=4
):
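    """Decode a video with decord (local path or 's3://' URI via `client`) and return sampled frames as PIL images; `clip=(start, end)` in seconds restricts sampling to that segment."""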
if 's3://' in video_path:
video_bytes = client.get(video_path)
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
else:
video_reader = VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
duration = vlen / float(fps)
if clip:
start, end = clip
duration = end - start
vlen = int(duration * fps)
start_index = int(start * fps)
# t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start,
input_fps=fps
)
if clip:
frame_indices = [f + start_index for f in frame_indices]
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
return frames
def extract_frame_number(filename):
# Extract the numeric part from the filename using regular expressions
    match = re.search(r'_(\d+)\.jpg$', filename)
return int(match.group(1)) if match else -1
def sort_frames(frame_paths):
# Extract filenames from each path and sort by their numeric part
return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
def read_frames_folder(
video_path, num_frames, sample='rand', fix_start=None,
client=None, clip=None, min_num_frames=4
):
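    """Load frames stored as individual images in a directory (local or 's3://'), sorted by the numeric suffix of their filenames, and subsample them if more frames than requested are found."""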
if 's3://' in video_path:
image_list = sort_frames(client.list(video_path))
frames = []
for image in image_list:
fp = os.path.join(video_path, image)
frame = Image.open(io.BytesIO(client.get(fp)))
frames.append(frame)
else:
image_list = sort_frames(list(os.listdir(video_path)))
frames = []
for image in image_list:
fp = os.path.join(video_path, image)
frame = Image.open(fp).convert('RGB')
frames.append(frame)
vlen = len(frames)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
if vlen > t_num_frames:
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start
)
frames = [frames[i] for i in frame_indices]
return frames
class WeightedConcatDataset(ConcatDataset):
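    """ConcatDataset that iterates through a WeightedRandomSampler built from the given `weights`, so sampling follows the provided weighting instead of sequential order."""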
def __init__(self, datasets, weights):
super().__init__(datasets)
self.weights = torch.DoubleTensor(weights)
self.total_size = sum(len(d) for d in datasets)
self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
def __iter__(self):
return iter(self.sampler)
def __len__(self):
return self.total_size
def pil_loader(img_str):
buff = io.BytesIO(img_str)
img = Image.open(buff)
return img.convert('RGB')
class TCSLoader(object):
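    """Reads images and videos through a petrel (Ceph) client; calling with image_type='image' returns a PIL image, while 'video' returns a list of frames from a frame folder, GIF, or video file."""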
def __init__(self, conf_path, sc_config_key='sensecore'):
print(f'[TCSLoader] config_path: {conf_path}')
print('--> before Client(conf_path)')
self.client = Client(conf_path)
self.sc_config_key = sc_config_key
print('--> after Client(conf_path)')
def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
if image_type == 'image':
img_value_str = self.client.get(fn)
img = pil_loader(img_value_str)
return img
elif image_type == 'video':
if fn.endswith('/'):
frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample)
elif fn.endswith('.gif'):
frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample)
else:
frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample, clip=clip)
return frames
def expand2square(pil_img, background_color):
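    """Pad a PIL image to a square canvas filled with `background_color`, centering the original along its shorter side."""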
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def simulate_jpeg_degradation(quality):
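    """Return a transform that round-trips an image through in-memory JPEG encoding at the given `quality`, simulating compression artifacts for augmentation."""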
def jpeg_degrade(img):
with io.BytesIO() as output:
img.convert('RGB').save(output, format='JPEG', quality=quality)
output.seek(0) # Move the reading cursor to the start of the stream
img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
return img_jpeg
return jpeg_degrade
# Define the JPEG compression quality range, pre-create all JPEG compression functions
qualities = list(range(75, 101))
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
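    """Build the image preprocessing pipeline: optional JPEG-degradation augmentation at train time, optional padding to square, bicubic resize to `input_size`, and normalization with ImageNet/CLIP/SigLIP statistics."""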
if normalize_type == 'imagenet':
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
elif normalize_type == 'clip':
MEAN, STD = CLIP_MEAN, CLIP_STD
elif normalize_type == 'siglip':
MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
else:
raise NotImplementedError
    if is_train:  # use data augmentation
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
else:
if pad2square is False: # now we use this transform function by default
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
else:
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def preprocess(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
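    """Apply the conversation template, expand each '<image>' placeholder into IMG_START/IMG_CONTEXT/IMG_END tokens, tokenize, and mask all non-assistant tokens with IGNORE_TOKEN_ID so loss is computed only on assistant responses."""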
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ': '
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
turns = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for i, turn in enumerate(turns):
if turn == '':
break
turn_len = len(tokenizer(turn).input_ids)
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# "-2" is hardcoded for the Llama tokenizer to make the offset correct.
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
if i != 0 and not tokenizer.legacy:
# The legacy and non-legacy modes handle special tokens differently
instruction_len -= 1
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len
if i != 0 and not tokenizer.legacy:
# The legacy and non-legacy modes handle special tokens differently
cur_len -= 1
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))
exit()
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_mpt(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
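    """Variant of `preprocess` for MPT-style chat templates separated by '<|im_end|><|im_start|>': turns are regrouped into system+user+assistant blocks before the system/user tokens are masked."""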
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
turns = conversation.split(conv.sep)
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
for conv_idx in range(3, len(turns), 2):
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_TOKEN_ID
for i, turn in enumerate(re_turns):
if turn == '':
break
turn_len = len(tokenizer(turn).input_ids) + 1
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
instruction_len = len(tokenizer(parts[0]).input_ids)
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
cur_len += turn_len
target[cur_len:] = IGNORE_TOKEN_ID
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_phi3(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
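    """Variant of `preprocess` for the Phi-3 chat template, whose assistant turns follow '<|end|>' and '<|assistant|>'; '<|endoftext|>' tokens are additionally excluded from the loss."""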
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
tokenizer.padding_side = 'right'
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
turns = conversation.split(conv.sep)
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
for conv_idx in range(3, len(turns), 2):
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
target[target == endoftext_id] = IGNORE_TOKEN_ID
for i, turn in enumerate(re_turns):
if turn == '':
break
if i == 0:
turn_len = len(tokenizer(turn).input_ids)
else:
turn_len = len(tokenizer(turn).input_ids) - 1
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if i == 0:
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
else:
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
cur_len += turn_len
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
print(repr(tokenizer.decode(z)))
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_internlm(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
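    """Variant of `preprocess` for the InternLM chat template: every span up to and including an assistant role tag is masked so the loss covers only assistant replies (InternLM ties pad_token_id to eos_token_id)."""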
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
sentence['value'] = sentence['value'].strip()
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # for InternLM, pad_token_id == eos_token_id
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID # <s>
parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n
info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # exclude the <s> token added by the tokenizer
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
cur_len = cur_len + temp_len
for index in range(1, len(parts) - 1):
info = parts[index]
part1, part2 = info.split(conv.roles[0])
temp_len = len(tokenizer(part1).input_ids) - 1
cur_len = cur_len + temp_len
part = conv.roles[0] + part2 + conv.roles[1]
temp_len = len(tokenizer(part).input_ids) - 1
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
cur_len = cur_len + temp_len
last_info = parts[-1]
temp_len = len(tokenizer(last_info).input_ids) - 1
cur_len = cur_len + temp_len
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
print(repr(tokenizer.decode(z)))
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_internvl2_5(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
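    """Tokenize each message separately with '<|im_start|>'/'<|im_end|>' wrappers, concatenate the pieces, and mask system/user messages plus the assistant role prefix; pads to `tokenizer.model_max_length` unless packing or length grouping is used."""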
    assert len(sources) == 1, 'process only the first conversation'
conversations = sources[0]
if conversations[0]['from'] == 'system':
system_prompt = conversations[0]['value']
conversations = conversations[1:] # remove system prompt
else:
conv = get_conv_template(template_name)
system_prompt = conv.system_message
# system_prompt = None
if not text_only:
new_conversations = []
current_image_idx = 0
for conversation in conversations:
if conversation['from'] == 'human':
image_cnt = conversation['value'].count('<image>')
for i in range(image_cnt):
if current_image_idx == num_image:
break
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1)
current_image_idx += 1
new_conversations.append(conversation)
conversations = new_conversations
assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'
batches, roles = [], []
if system_prompt is not None:
batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
roles.append('system')
for conversation in conversations:
if conversation['from'] == 'human':
batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
roles.append('human')
elif conversation['from'] == 'gpt':
batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
roles.append('gpt')
else:
raise NotImplementedError
add_bos_token = getattr(tokenizer, 'add_bos_token', False)
if add_bos_token: # for InternLM series
batches[0] = tokenizer.bos_token + batches[0]
# Tokenize conversations
input_ids = tokenizer(
batches,
return_tensors='np',
padding=False,
max_length=tokenizer.model_max_length,
truncation=False,
).input_ids
if add_bos_token: # for InternLM series
input_ids = [item[1:] for item in input_ids]
final_input_ids, final_targets = [], []
ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
for role, input_id in zip(roles, input_ids):
final_input_ids.append(input_id)
if role == 'system' or role == 'human':
final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID)) # ignore
elif role == 'gpt':
target = input_id.copy()
target[:ignore_len] = IGNORE_TOKEN_ID # ignore loss for `<|im_start|>assistant\n`
target[-1:] = IGNORE_TOKEN_ID # ignore loss for `\n`
final_targets.append(target)
else:
raise NotImplementedError
input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]
padding = False if group_by_length or use_packed_ds else True
if padding:
current_length = input_ids.size(0)
padding_length = tokenizer.model_max_length - current_length
input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)
input_ids = input_ids.unsqueeze(0)
targets = targets.unsqueeze(0)
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
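    """Pick the (columns, rows) tiling from `target_ratios` whose aspect ratio is closest to `aspect_ratio`; ties favor the larger grid when the image area exceeds half of the tiled area."""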
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False):
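    """Split an image into between `min_num` and `max_num` tiles of `image_size` x `image_size` that best match its aspect ratio; optionally append a thumbnail of the whole image and, with `return_ratio=True`, also return the chosen grid."""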
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
if return_ratio:
return processed_images, target_aspect_ratio
return processed_images
def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
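    """Tile a stack of masks of shape (num_masks, H, W) with the same grid selection as `dynamic_preprocess`, resizing with bilinear interpolation and thresholding at zero to keep the masks boolean."""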
length, orig_height, orig_width = image.shape
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# print(target_aspect_ratio)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
    tensor_images = image.unsqueeze(1)  # add a channel dimension so each mask is resized as a single-channel image
    resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False)
resized_images = resized_images > 0
    # crop the resized masks into tiles, mirroring the PIL-based cropping in dynamic_preprocess
processed_images = []
for i in range(blocks):
top = (i // (target_width // image_size)) * image_size
left = (i % (target_width // image_size)) * image_size
bottom = top + image_size
right = left + image_size
        # crop via tensor slicing; the ellipsis keeps the channel dimension intact
        split_img = resized_images[..., top:bottom, left:right]
processed_images.append(split_img)
    # squeeze out the channel dimension that was added for interpolation
processed_images = [img.squeeze(1) for img in processed_images]
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False).squeeze(1)
thumbnail_img = thumbnail_img > 0
# Image.fromarray(thumbnail_img.cpu().numpy().astype(np.uint8))
processed_images.append(thumbnail_img)
return processed_images