File size: 19,617 Bytes
d93d2f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 |
"""
Fine-tune Phi-4-multimodal-instruct on an image task (PMC-VQA multiple-choice visual question answering).
scipy==1.15.1
peft==0.13.2
backoff==2.2.1
transformers==4.47.0
accelerate==1.3.0
"""
import argparse
import json
import os
import tempfile
import zipfile
from pathlib import Path
import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
BatchFeature,
Trainer,
TrainingArguments,
)
# Instruction appended to every prompt after the question and its four choices.
# NOTE(review): the name is misspelled ("INSTSRUCTION"); it is kept as-is because
# both dataset classes use it as a default argument.
DEFAULT_INSTSRUCTION = "Answer with the option's letter from the given choices directly."
# Label value written at every non-answer position by PmcVqaTrainDataset
# (-100 is the default `ignore_index` of torch.nn.CrossEntropyLoss).
_IGNORE_INDEX = -100
# Number of training / evaluation rows used by main() unless --full_run is passed.
_TRAIN_SIZE = 8000
_EVAL_SIZE = 500
# Training samples longer than this many tokens are truncated in
# PmcVqaTrainDataset.__getitem__.
_MAX_TRAINING_LENGTH = 8192
class PmcVqaTrainDataset(Dataset):
    """
    Training split of PMC-VQA (multiple-choice visual question answering).

    On construction this downloads `images_2.zip` from the `xmcmic/PMC-VQA`
    dataset repo, extracts it into a fresh temporary folder and loads
    `train_2.csv` as the annotations. Each item is one tokenized
    prompt + answer pair whose labels cover only the answer tokens.
    """

    def __init__(self, processor, data_size, instruction=DEFAULT_INSTSRUCTION):
        """
        processor: object exposing `.tokenizer` and callable as
            `processor(prompt, images=[...], return_tensors='pt')`
        data_size: number of leading rows to keep, or None for the whole split
        instruction: text appended after the question and the choices
        """
        # Download the file
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',  # repository name
            filename='images_2.zip',  # file to download
            repo_type='dataset',  # specify it's a dataset repo
        )
        # file_path will be the local path where the file was downloaded
        print(f'File downloaded to: {file_path}')

        # Remember which process created the temp folder: copies of this object
        # that end up in other processes (e.g. DataLoader workers) must not
        # delete it in their __del__.
        self._owner_pid = os.getpid()
        # unzip to temp folder
        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'train': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/train_2.csv',
        }
        split = 'train' if data_size is None else f'train[:{data_size}]'
        self.annotations = load_dataset('xmcmic/PMC-VQA', data_files=data_files, split=split)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Build the model inputs and labels for annotation `idx`.

        Example annotation:
        {'index': 35,
        'Figure_path': 'PMC8253797_Fig4_11.jpg',
        'Caption': 'A slightly altered cell . (c-c‴) A highly altered cell as seen from 4 different angles . Note mitochondria/mitochondrial networks (green), Golgi complexes (red), cell nuclei (light blue) and the cell outline (yellow).',
        'Question': ' What color is used to label the Golgi complexes in the image?',
        'Choice A': ' A: Green ',
        'Choice B': ' B: Red ',
        'Choice C': ' C: Light blue ',
        'Choice D': ' D: Yellow',
        'Answer': 'B',
        'split': 'train'}

        returns: dict with 'input_ids' and 'labels' (both [1, seq_len]) plus the
            processor's 'input_image_embeds', 'image_attention_mask' and
            'image_sizes' for the image.
        """
        annotation = self.annotations[idx]
        question = annotation['Question']
        choices = [annotation[f'Choice {letter}'] for letter in 'ABCD']
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = f'{annotation["Answer"]}<|end|><|endoftext|>'

        image_path = self.image_folder / 'figures' / annotation['Figure_path']
        # Image.open is lazy; the context manager guarantees the file handle is
        # released as soon as the processor has consumed the pixels.
        with Image.open(image_path) as image:
            inputs = self.processor(prompt, images=[image], return_tensors='pt')

        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids
        input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
        # Only the answer tokens contribute to the loss.
        labels = torch.full_like(input_ids, _IGNORE_INDEX)
        labels[:, -answer_ids.shape[1] :] = answer_ids

        if input_ids.size(1) > _MAX_TRAINING_LENGTH:
            input_ids = input_ids[:, :_MAX_TRAINING_LENGTH]
            labels = labels[:, :_MAX_TRAINING_LENGTH]
            if torch.all(labels == _IGNORE_INDEX).item():
                # workaround to make sure loss compute won't fail
                labels[:, -1] = self.processor.tokenizer.eos_token_id

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
        }

    def __del__(self):
        """Best-effort removal of the extracted image folder."""
        folder = getattr(self, 'image_folder', None)
        if folder is None:
            # __init__ failed before the folder was created: nothing to clean up
            return
        if getattr(self, '_owner_pid', None) != os.getpid():
            # This copy lives in another process; the creating process still
            # reads from the folder and is responsible for removing it.
            return
        import shutil

        shutil.rmtree(folder, ignore_errors=True)
class PmcVqaEvalDataset(Dataset):
    """
    Test split of PMC-VQA (multiple-choice visual question answering).

    On construction this downloads `images_2.zip` from the `xmcmic/PMC-VQA`
    dataset repo, extracts it into a fresh temporary folder and loads
    `test_2.csv`, keeping only shard `rank` out of `world_size` shards so that
    each process evaluates a different part of the split.
    """

    def __init__(
        self, processor, data_size, instruction=DEFAULT_INSTSRUCTION, rank=0, world_size=1
    ):
        """
        processor: object exposing `.tokenizer` and callable as
            `processor(prompt, images=[...], return_tensors='pt')`
        data_size: number of leading rows to keep (before sharding), or None
            for the whole split
        instruction: text appended after the question and the choices
        rank: index of the shard kept by this process
        world_size: total number of shards
        """
        # Download the file
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',  # repository name
            filename='images_2.zip',  # file to download
            repo_type='dataset',  # specify it's a dataset repo
        )
        # file_path will be the local path where the file was downloaded
        print(f'File downloaded to: {file_path}')

        # Remember which process created the temp folder: copies of this object
        # that end up in other processes (e.g. DataLoader workers) must not
        # delete it in their __del__.
        self._owner_pid = os.getpid()
        # unzip to temp folder
        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'test': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/test_2.csv',
        }
        split = 'test' if data_size is None else f'test[:{data_size}]'
        self.annotations = load_dataset(
            'xmcmic/PMC-VQA', data_files=data_files, split=split
        ).shard(num_shards=world_size, index=rank)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        """Return the number of annotation rows in this process's shard."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Build the generation inputs for annotation `idx`.

        Example annotation:
        {'index': 62,
        'Figure_path': 'PMC8253867_Fig2_41.jpg',
        'Caption': 'CT pulmonary angiogram reveals encasement and displacement of the left anterior descending coronary artery ( blue arrows ).',
        'Question': ' What is the name of the artery encased and displaced in the image? ',
        'Choice A': ' A: Right Coronary Artery ',
        'Choice B': ' B: Left Anterior Descending Coronary Artery ',
        'Choice C': ' C: Circumflex Coronary Artery ',
        'Choice D': ' D: Superior Mesenteric Artery ',
        'Answer': 'B',
        'split': 'test'}

        returns: dict with a zero-padded string 'id' built from the annotation's
            'index', the prompt 'input_ids', the processor's image tensors and
            the reference 'answer'.
        """
        annotation = self.annotations[idx]
        question = annotation['Question']
        choices = [annotation[f'Choice {letter}'] for letter in 'ABCD']
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = annotation['Answer']

        image_path = self.image_folder / 'figures' / annotation['Figure_path']
        # Image.open is lazy; the context manager guarantees the file handle is
        # released as soon as the processor has consumed the pixels.
        with Image.open(image_path) as image:
            inputs = self.processor(prompt, images=[image], return_tensors='pt')

        unique_id = f'{annotation["index"]:010d}'
        return {
            'id': unique_id,
            'input_ids': inputs.input_ids,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
            'answer': answer,
        }

    def __del__(self):
        """Best-effort removal of the extracted image folder."""
        folder = getattr(self, 'image_folder', None)
        if folder is None:
            # __init__ failed before the folder was created: nothing to clean up
            return
        if getattr(self, '_owner_pid', None) != os.getpid():
            # This copy lives in another process; the creating process still
            # reads from the folder and is responsible for removing it.
            return
        import shutil

        shutil.rmtree(folder, ignore_errors=True)
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length and stack them into one batch.

    sequences: non-empty list of tensors in [seq_len, *] shape; the trailing
        dims, dtype and device of the output are taken from the first tensor
    padding_side: 'right' (data first, padding after) or 'left'
    padding_value: value written at the padded positions
    returns: tensor of shape [len(sequences), max_seq_len, *]
    raises: ValueError for an unknown padding_side or an empty list
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(f"padding_side must be 'right' or 'left', got {padding_side!r}")
    if not sequences:
        raise ValueError('sequences must contain at least one tensor')

    first = sequences[0]
    trailing_dims = tuple(first.shape[1:])
    max_len = max(seq.size(0) for seq in sequences)
    output = first.new_full((len(sequences), max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        # detach(): the padded batch is plain data, never part of the autograd graph
        if padding_side == 'right':
            output[i, :length] = seq.detach()
        else:
            # Explicit start index: `-length:` would select the WHOLE row when
            # length == 0 and make the assignment fail.
            output[i, max_len - length :] = seq.detach()
    return output
def cat_with_pad(tensors, dim, padding_value=0):
    """
    Concatenate tensors along `dim`, padding every other dim to its maximum size.

    tensors: non-empty list of tensors with the same number of dimensions
    dim: dimension to concatenate along
    padding_value: value written where a tensor is smaller than the output
    returns: tensor whose size along `dim` is the sum of the inputs' sizes and
        whose size along every other dim is the maximum over the inputs; dtype
        and device come from the first tensor
    raises: ValueError for an empty list or mismatching numbers of dimensions
    """
    if not tensors:
        raise ValueError('tensors must contain at least one tensor')
    ndim = tensors[0].dim()
    if any(t.dim() != ndim for t in tensors[1:]):
        raise ValueError('All tensors must have the same number of dimensions')

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    offset = 0
    for t in tensors:
        # Region of the output covered by this tensor: its own extent in every
        # dim, shifted by the running offset along the concat dim.
        region = [slice(0, size) for size in t.shape]
        region[dim] = slice(offset, offset + t.shape[dim])
        # Index with a tuple: indexing with a list of slices relies on a
        # deprecated "sequence as tuple" fallback.
        output[tuple(region)] = t
        offset += t.shape[dim]
    return output
def pmc_vqa_collate_fn(batch):
    """
    Collate PmcVqaTrainDataset items into one training batch.

    batch: list of dicts as returned by PmcVqaTrainDataset.__getitem__
    returns: BatchFeature with right-padded 'input_ids', 'labels' and
        'attention_mask', the image tensors concatenated along dim 0 (padded to
        a common size in the other dims), and 'input_mode' set to 1
    """
    input_ids_list = []
    labels_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    for inputs in batch:
        # dataset items carry a leading batch dim of size 1; drop it
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])

    input_ids = pad_sequence(input_ids_list, padding_side='right', padding_value=0)
    # Padded label positions must be ignored by the loss; padding them with 0
    # would train the model to predict token id 0 after every short sample.
    labels = pad_sequence(labels_list, padding_side='right', padding_value=_IGNORE_INDEX)
    # Build the mask from the true sequence lengths instead of `input_ids != 0`,
    # which would also mask any real token whose id happens to be 0.
    attention_mask = pad_sequence(
        [torch.ones_like(ids, dtype=torch.long) for ids in input_ids_list],
        padding_side='right',
        padding_value=0,
    )
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_image_embeds': input_image_embeds,
            'image_attention_mask': image_attention_mask,
            'image_sizes': image_sizes,
            'input_mode': 1,  # vision mode
        }
    )
def pmc_vqa_eval_collate_fn(batch):
    """
    Collate PmcVqaEvalDataset items into one generation batch.

    batch: list of dicts as returned by PmcVqaEvalDataset.__getitem__
    returns: tuple (ids, answers, features) where `ids` and `answers` are lists
        in batch order and `features` is a BatchFeature with left-padded
        'input_ids' / 'attention_mask', the image tensors concatenated along
        dim 0, and 'input_mode' set to 1
    """
    input_ids_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    all_unique_ids = []
    all_answers = []
    for inputs in batch:
        # dataset items carry a leading batch dim of size 1; drop it
        input_ids_list.append(inputs['input_ids'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])
        all_unique_ids.append(inputs['id'])
        all_answers.append(inputs['answer'])

    # Left padding keeps every prompt adjacent to the tokens generated after it.
    input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
    # Build the mask from the true sequence lengths instead of `input_ids != 0`,
    # which would also mask any real token whose id happens to be 0.
    attention_mask = pad_sequence(
        [torch.ones_like(ids, dtype=torch.long) for ids in input_ids_list],
        padding_side='left',
        padding_value=0,
    )
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return (
        all_unique_ids,
        all_answers,
        BatchFeature(
            {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'input_image_embeds': input_image_embeds,
                'image_attention_mask': image_attention_mask,
                'image_sizes': image_sizes,
                'input_mode': 1,  # vision mode
            }
        ),
    )
def create_model(model_name_or_path, use_flash_attention=False):
    """
    Load the causal LM on the GPU and strip its audio-only parameters.

    model_name_or_path: hub id or local path passed to `from_pretrained`
    use_flash_attention: when true, load in bfloat16 with the
        'flash_attention_2' attention implementation; otherwise float32 with
        'sdpa'
    returns: the loaded model, moved to 'cuda', without the audio embedding
        module and without the 'speech' LoRA weights
    """
    if use_flash_attention:
        dtype, attn_impl = torch.bfloat16, 'flash_attention_2'
    else:
        dtype, attn_impl = torch.float32, 'sdpa'

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        _attn_implementation=attn_impl,
        trust_remote_code=True,
    )
    model = model.to('cuda')

    # remove parameters irrelevant to vision tasks
    del model.model.embed_tokens_extend.audio_embed  # remove audio encoder
    for layer in model.model.layers:
        # every LoRA-wrapped projection carries a 'speech' adapter; drop both halves
        lora_projections = (
            layer.mlp.down_proj,
            layer.mlp.gate_up_proj,
            layer.self_attn.o_proj,
            layer.self_attn.qkv_proj,
        )
        for projection in lora_projections:
            del projection.lora_A.speech
            del projection.lora_B.speech
    # TODO remove unused vision layers?
    return model
@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    """
    Generate answers for `eval_dataset` and compute exact-match accuracy.

    Each process runs generation on its own dataset, then the per-process
    results are gathered across processes. Only the process whose RANK
    environment variable is 0 (or unset) computes and returns the accuracy.

    model: model exposing `.eval()` and `.generate(...)`
    processor: object exposing `.tokenizer.eos_token_id` and `.batch_decode(...)`
    eval_dataset: dataset whose items are accepted by pmc_vqa_eval_collate_fn
    save_path: optional path of a JSON file receiving answers, generated texts
        and the accuracy (written by rank 0 only)
    disable_tqdm: hide the progress bar
    eval_batch_size: DataLoader batch size
    returns: accuracy as a float on rank 0, None on every other rank
    """
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    device = f'cuda:{local_rank}'
    hide_progress = (rank != 0) or disable_tqdm

    model.eval()
    loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=pmc_vqa_eval_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=4,
        prefetch_factor=2,
        pin_memory=True,
    )

    references = []
    predictions = []
    for ids, answers, inputs in tqdm(loader, disable=hide_progress, desc='running eval'):
        for sample_id, answer in zip(ids, answers):
            references.append({'id': sample_id, 'answer': answer.strip().lower()})

        inputs = inputs.to(device)
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64
        )
        # keep only the tokens generated after the (padded) prompt
        prompt_len = inputs.input_ids.size(1)
        decoded = processor.batch_decode(
            generated_ids[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        for sample_id, text in zip(ids, decoded):
            predictions.append({'id': sample_id, 'generated_text': text.strip().lower()})

    # gather outputs from all ranks
    references = gather_object(references)
    predictions = gather_object(predictions)

    if rank != 0:
        return None

    assert len(references) == len(predictions)
    num_correct = sum(
        ref['answer'] == pred['generated_text'] for ref, pred in zip(references, predictions)
    )
    acc = num_correct / len(references)
    if save_path:
        report = {
            'answers_unique': references,
            'generated_texts_unique': predictions,
            'accuracy': acc,
        }
        with open(save_path, 'w') as f:
            json.dump(report, f)
    return acc
def main():
    """
    Command-line entry point.

    Parses the CLI arguments, loads the processor and model, evaluates the model
    on the PMC-VQA eval dataset, fine-tunes it with the Hugging Face Trainer,
    saves it to --output_dir, then reloads the saved model and evaluates again.
    Accuracies are printed by the main process and written to
    `eval_before.json` / `eval_after.json` inside the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=1,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--dynamic_hd',
        type=int,
        default=36,
        help='Number of maximum image crops',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no_tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    parser.add_argument('--full_run', action='store_true', help='Run the full training and eval')
    args = parser.parse_args()

    accelerator = Accelerator()

    # NOTE(review): the indentation of this file was reconstructed; the
    # with-block is assumed to cover only processor/model loading — confirm.
    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            dynamic_hd=args.dynamic_hd,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )
    # tune vision encoder and lora
    model.set_lora_adapter('vision')
    for param in model.model.embed_tokens_extend.image_embed.parameters():
        param.requires_grad = True

    # RANK / WORLD_SIZE are read from the environment and default to a
    # single-process run when unset.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    # Without --full_run only the first _TRAIN_SIZE / _EVAL_SIZE rows are used.
    train_dataset = PmcVqaTrainDataset(processor, data_size=None if args.full_run else _TRAIN_SIZE)
    eval_dataset = PmcVqaEvalDataset(
        processor,
        data_size=None if args.full_run else _EVAL_SIZE,
        rank=rank,
        world_size=world_size,
    )

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    # --batch_size is the global batch size; it is reached through gradient
    # accumulation, so it must be a multiple of num_gpus * batch_size_per_gpu.
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)

    # create_model loads the weights in bfloat16 when flash attention is
    # requested and in float32 otherwise; pick the Trainer precision flag to match.
    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    # hard coded training args
    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,  # for unused SigLIP layers
    )

    # eval before fine-tuning
    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # evaluate() returns the accuracy on rank 0 and None on every other rank
    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy before finetuning: {acc}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=pmc_vqa_collate_fn,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()
    # make sure every process waits until the model has been saved before reloading it
    accelerator.wait_for_everyone()

    # eval after fine-tuning (load saved checkpoint)
    # first try to clear GPU memory
    del model
    del trainer
    __import__('gc').collect()
    torch.cuda.empty_cache()

    # reload the model for inference
    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy after finetuning: {acc}')
# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()