"""
finetune Phi-4-multimodal-instruct on an image task

scipy==1.15.1
peft==0.13.2
backoff==2.2.1
transformers==4.47.0
accelerate==1.3.0
"""
import argparse
import gc
import json
import os
import shutil
import tempfile
import zipfile
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BatchFeature,
    Trainer,
    TrainingArguments,
)

DEFAULT_INSTRUCTION = "Answer with the option's letter from the given choices directly."
_IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss; masked positions add no loss
_TRAIN_SIZE = 8000
_EVAL_SIZE = 500
_MAX_TRAINING_LENGTH = 8192


class PmcVqaTrainDataset(Dataset):
    def __init__(self, processor, data_size, instruction=DEFAULT_INSTRUCTION):
        # Download and extract the image archive from the PMC-VQA dataset repo.
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',
            filename='images_2.zip',
            repo_type='dataset',
        )
        print(f'File downloaded to: {file_path}')

        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'train': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/train_2.csv',
        }
        split = 'train' if data_size is None else f'train[:{data_size}]'
        self.annotations = load_dataset('xmcmic/PMC-VQA', data_files=data_files, split=split)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        {'index': 35,
         'Figure_path': 'PMC8253797_Fig4_11.jpg',
         'Caption': 'A slightly altered cell . (c-c‴) A highly altered cell as seen from 4 different angles . Note mitochondria/mitochondrial networks (green), Golgi complexes (red), cell nuclei (light blue) and the cell outline (yellow).',
         'Question': ' What color is used to label the Golgi complexes in the image?',
         'Choice A': ' A: Green ',
         'Choice B': ' B: Red ',
         'Choice C': ' C: Light blue ',
         'Choice D': ' D: Yellow',
         'Answer': 'B',
         'split': 'train'}
        """
        annotation = self.annotations[idx]
        image = Image.open(self.image_folder / 'figures' / annotation['Figure_path'])
        question = annotation['Question']
        choices = [annotation[f'Choice {chr(ord("A") + i)}'] for i in range(4)]
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = f'{annotation["Answer"]}<|end|><|endoftext|>'
        inputs = self.processor(prompt, images=[image], return_tensors='pt')

        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids

        # Supervise only the answer tokens; prompt tokens are masked out with _IGNORE_INDEX.
        input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
        labels = torch.full_like(input_ids, _IGNORE_INDEX)
        labels[:, -answer_ids.shape[1] :] = answer_ids

        if input_ids.size(1) > _MAX_TRAINING_LENGTH:
            input_ids = input_ids[:, :_MAX_TRAINING_LENGTH]
            labels = labels[:, :_MAX_TRAINING_LENGTH]
            if torch.all(labels == _IGNORE_INDEX).item():
                # Truncation dropped every answer token; keep one supervised token so the loss is defined.
                labels[:, -1] = self.processor.tokenizer.eos_token_id

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
        }

    def __del__(self):
        shutil.rmtree(self.image_folder)


class PmcVqaEvalDataset(Dataset):
    def __init__(
        self, processor, data_size, instruction=DEFAULT_INSTRUCTION, rank=0, world_size=1
    ):
        # Download and extract the image archive from the PMC-VQA dataset repo.
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',
            filename='images_2.zip',
            repo_type='dataset',
        )
        print(f'File downloaded to: {file_path}')

        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'test': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/test_2.csv',
        }
        split = 'test' if data_size is None else f'test[:{data_size}]'
        # Shard the eval split so each rank evaluates a disjoint subset.
        self.annotations = load_dataset(
            'xmcmic/PMC-VQA', data_files=data_files, split=split
        ).shard(num_shards=world_size, index=rank)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        {'index': 62,
         'Figure_path': 'PMC8253867_Fig2_41.jpg',
         'Caption': 'CT pulmonary angiogram reveals encasement and displacement of the left anterior descending coronary artery ( blue arrows ).',
         'Question': ' What is the name of the artery encased and displaced in the image? ',
         'Choice A': ' A: Right Coronary Artery ',
         'Choice B': ' B: Left Anterior Descending Coronary Artery ',
         'Choice C': ' C: Circumflex Coronary Artery ',
         'Choice D': ' D: Superior Mesenteric Artery ',
         'Answer': 'B',
         'split': 'test'}
        """
        annotation = self.annotations[idx]
        image = Image.open(self.image_folder / 'figures' / annotation['Figure_path'])
        question = annotation['Question']
        choices = [annotation[f'Choice {chr(ord("A") + i)}'] for i in range(4)]
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = annotation['Answer']
        inputs = self.processor(prompt, images=[image], return_tensors='pt')

        unique_id = f'{annotation["index"]:010d}'
        return {
            'id': unique_id,
            'input_ids': inputs.input_ids,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
            'answer': answer,
        }

    def __del__(self):
        shutil.rmtree(self.image_folder)


def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output
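# Example: pad_sequence([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], padding_side='left')
# -> tensor([[1, 2, 3],
#            [0, 4, 5]])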
def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Copy each tensor into its slice along `dim`; the other dims stay padded up to the max size.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
        index += t.shape[dim]

    return output
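# Example: cat_with_pad([torch.ones(2, 3), torch.ones(1, 5)], dim=0) returns a (3, 5) tensor
# where the first two rows are zero-padded from width 3 to width 5.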
def pmc_vqa_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])

    input_ids = pad_sequence(input_ids_list, padding_side='right', padding_value=0)
    # Pad labels with _IGNORE_INDEX so padded positions do not contribute to the loss.
    labels = pad_sequence(labels_list, padding_side='right', padding_value=_IGNORE_INDEX)
    attention_mask = (input_ids != 0).long()
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_image_embeds': input_image_embeds,
            'image_attention_mask': image_attention_mask,
            'image_sizes': image_sizes,
            'input_mode': 1,  # vision input mode
        }
    )


def pmc_vqa_eval_collate_fn(batch):
    input_ids_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    all_unique_ids = []
    all_answers = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])
        all_unique_ids.append(inputs['id'])
        all_answers.append(inputs['answer'])
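    # Left-pad prompts for batched generation so new tokens are appended directly after each prompt.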
    input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
    attention_mask = (input_ids != 0).long()
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return (
        all_unique_ids,
        all_answers,
        BatchFeature(
            {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'input_image_embeds': input_image_embeds,
                'image_attention_mask': image_attention_mask,
                'image_sizes': image_sizes,
                'input_mode': 1,
            }
        ),
    )


def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')

    # Vision-only finetune: drop the audio embedding module and the speech LoRA weights,
    # which are not used here.
    del model.model.embed_tokens_extend.audio_embed
    for layer in model.model.layers:
        del layer.mlp.down_proj.lora_A.speech
        del layer.mlp.down_proj.lora_B.speech
        del layer.mlp.gate_up_proj.lora_A.speech
        del layer.mlp.gate_up_proj.lora_B.speech
        del layer.self_attn.o_proj.lora_A.speech
        del layer.self_attn.o_proj.lora_B.speech
        del layer.self_attn.qkv_proj.lora_A.speech
        del layer.self_attn.qkv_proj.lora_B.speech

    return model


@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    model.eval()
    all_answers = []
    all_generated_texts = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=pmc_vqa_eval_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=4,
        prefetch_factor=2,
        pin_memory=True,
    )
    for ids, answers, inputs in tqdm(
        eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'
    ):
        all_answers.extend({'id': i, 'answer': a.strip().lower()} for i, a in zip(ids, answers))

        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64
        )

        input_len = inputs.input_ids.size(1)
        generated_texts = processor.batch_decode(
            generated_ids[:, input_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        all_generated_texts.extend(
            {'id': i, 'generated_text': g.strip().lower()} for i, g in zip(ids, generated_texts)
        )

    # Gather predictions and references from all ranks before scoring.
    all_answers = gather_object(all_answers)
    all_generated_texts = gather_object(all_generated_texts)

    if rank == 0:
        assert len(all_answers) == len(all_generated_texts)
        acc = sum(
            a['answer'] == g['generated_text'] for a, g in zip(all_answers, all_generated_texts)
        ) / len(all_answers)
        if save_path:
            with open(save_path, 'w') as f:
                save_dict = {
                    'answers_unique': all_answers,
                    'generated_texts_unique': all_generated_texts,
                    'accuracy': acc,
                }
                json.dump(save_dict, f)

        return acc
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=1,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--dynamic_hd',
        type=int,
        default=36,
        help='Number of maximum image crops',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no_tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    parser.add_argument('--full_run', action='store_true', help='Run the full training and eval')
    args = parser.parse_args()

    accelerator = Accelerator()

    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            dynamic_hd=args.dynamic_hd,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )
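    # Activate the vision LoRA adapter and make the image embedding module trainable.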
    model.set_lora_adapter('vision')
    for param in model.model.embed_tokens_extend.image_embed.parameters():
        param.requires_grad = True

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    train_dataset = PmcVqaTrainDataset(processor, data_size=None if args.full_run else _TRAIN_SIZE)
    eval_dataset = PmcVqaEvalDataset(
        processor,
        data_size=None if args.full_run else _EVAL_SIZE,
        rank=rank,
        world_size=world_size,
    )

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs times the per-GPU batch size'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)
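    # Effective global batch size = num_gpus * batch_size_per_gpu * gradient_accumulation_steps = args.batch_size.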
    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,
    )

    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy before finetuning: {acc}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=pmc_vqa_collate_fn,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()
    accelerator.wait_for_everyone()

    # Free the training model and trainer before reloading the finetuned weights for evaluation.
    del model
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy after finetuning: {acc}')


if __name__ == '__main__':
    main()