|
from datasets import Dataset, DatasetDict, Image |
|
import pandas as pd |
|
import os |
|
|
|
import torch |
|
from peft import LoraConfig |
|
from transformers import AutoProcessor, BitsAndBytesConfig |
|
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq |
|
from datetime import datetime |
|
import evaluate |
|
|
|
TRAIN_SAMPLES = 1000  # NOTE(review): not referenced anywhere in this script
TEST_SAMPLES = 200  # NOTE(review): not referenced anywhere in this script
TEST_SIZE = 0.166  # fraction of the combined dataset held out for evaluation


def _load_annotations(csv_path, image_prefix):
    """Load one annotation CSV and build the image path of every row.

    Rows containing any missing value are dropped. ``image_prefix`` is a
    directory *plus* a filename prefix, so each path is built by plain string
    concatenation with the row's ``filename`` column.

    Args:
        csv_path: path of the annotation CSV (must have ``filename`` and
            ``text`` columns).
        image_prefix: string prepended to every ``filename``.

    Returns:
        ``(frame, paths)``: the cleaned DataFrame (with ``id`` and ``query``
        columns added) and the list of image paths in row order.
    """
    frame = pd.read_csv(csv_path)
    frame.dropna(inplace=True)
    frame["id"] = range(frame.shape[0])
    frame["query"] = "What is shown in this image?"
    paths = [image_prefix + name for name in frame.filename]
    return frame, paths


# First annotation batch.
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
df, image_paths = _load_annotations(df_path, root_dir)

# Second annotation batch.
df_path_2 = "/mnt/data1/Datasets/AlphaPen/" + "training_b2.csv"
root_dir_2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"
df_2, image_paths_2 = _load_annotations(df_path_2, root_dir_2)

all_image_paths = image_paths + image_paths_2

# Fail fast: the `Image()` feature only opens files when a sample is accessed,
# so a missing file would otherwise surface mid-training inside the collator.
missing_images = [path for path in all_image_paths if not os.path.isfile(path)]
if missing_images:
    raise FileNotFoundError(
        f"{len(missing_images)} image file(s) referenced by the CSVs do not exist, "
        f"e.g. {missing_images[0]}"
    )

# Global ids across both batches (the per-frame "id" columns restart at 0).
ids = list(range(len(all_image_paths)))
queries = df["query"].tolist() + df_2["query"].tolist()
answers = df["text"].tolist() + df_2["text"].tolist()

dataset_dict = {
    "id": ids,
    "image": all_image_paths,
    "query": queries,
    "answers": answers,
}

dataset = Dataset.from_dict(dataset_dict)
# Turn the path strings into images that are decoded on access.
dataset = dataset.cast_column("image", Image())

# NOTE(review): with shuffle=False the eval split is the *tail* of the
# concatenation, i.e. (almost) only rows from training_b2.csv. Confirm this is
# intended; otherwise use shuffle=True with a fixed seed.
split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)

train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(len(train_dataset))
|
|
|
|
|
|
|
# Weights & Biases project that `report_to="wandb"` logs into.
os.environ["WANDB_PROJECT"] = "Alphapen"

model_id = "HuggingFaceM4/idefics2-8b"

DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True

# do_image_splitting=False: feed each image whole instead of splitting it into
# sub-crops (presumably to keep memory / image-token count down).
processor = AutoProcessor.from_pretrained(
    model_id,
    do_image_splitting=False,
)

if USE_QLORA or USE_LORA:
    # Adapters on the attention/MLP projections of the text model, the
    # modality projection and the perceiver resampler; the regex does not
    # match the vision encoder, so it is left untouched.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        # DoRA only for plain LoRA; QLoRA uses standard LoRA adapters.
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian",
    )

    bnb_config = None
    if USE_QLORA:
        # 4-bit NF4 weights with fp16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    # Full fine-tuning with flash attention, placed on DEVICE explicitly.
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
        trust_remote_code=True,
    ).to(DEVICE)

# Applied to every branch so adapter training and full fine-tuning share the
# same pad id and generation length during evaluation.
# NOTE(review): decoder_start_token_id only matters for encoder-decoder models;
# Idefics2 is presumably decoder-only, so this is likely a no-op — confirm.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# NOTE(review): recent transformers read generation limits from
# model.generation_config; confirm this value actually reaches generate().
model.config.max_length = 128
|
|
|
|
|
|
|
import random |
|
|
|
class MyDataCollator:
    """Turn raw dataset rows into a padded Idefics2 training batch.

    Each example must provide ``image``, ``query`` and ``answers`` (a single
    answer string). It is rendered as a two-turn chat with the processor's
    chat template:

    - user turn: an OCR instruction, the image and the query;
    - assistant turn: the answer.

    All examples are then tokenized together with padding.

    ``labels`` is a copy of ``input_ids`` in which padding and ``<image>``
    placeholder positions are set to -100, so the loss ignores them.
    """

    def __init__(self, processor):
        self.processor = processor
        tokenizer = processor.tokenizer
        # Id of the "<image>" placeholder among the tokenizer's additional
        # special tokens.
        self.image_token_id = tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "OCR the text in the image."},
                        {"type": "image"},
                        {"type": "text", "text": example["query"]},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": example["answers"]},
                    ],
                },
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            # One list of images per text sample.
            images.append([example["image"]])

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        # -100 is the default ignore_index of PyTorch's cross-entropy loss:
        # don't train the model to predict padding or image placeholders.
        pad_token_id = self.processor.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        labels[labels == self.image_token_id] = -100
        batch["labels"] = labels

        return batch
|
|
|
# Collator instance handed to the Seq2SeqTrainer for batching.
data_collator = MyDataCollator(processor=processor)
|
|
|
from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments |
|
|
|
# Training / evaluation schedule. Seq2SeqTrainingArguments is used for
# `predict_with_generate`, so evaluation runs generation and
# `compute_metrics` receives generated token ids.
# NOTE(review): evaluation uses the same collator as training, so the inputs
# given to generate() already contain the reference answer. For a
# decoder-only model the generated ids would then start with prompt + answer.
# Confirm whether an eval-only collator (prompt only,
# add_generation_prompt=True) is needed for a meaningful CER.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    output_dir="idefics2",
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    dataloader_pin_memory=False,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=200,
    save_steps=10000,
    max_steps=50000,
    logging_steps=10,
    # Keep the raw "image"/"query"/"answers" columns: the custom collator reads them.
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=False,
    report_to="wandb",
    optim="paged_adamw_8bit",
    # %S is the seconds field. Lowercase %s is a non-portable glibc extension
    # that expands to seconds since the epoch.
    run_name=f"idefics2-vision-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
    hub_model_id="hadrakey/alphapen_idefics2_finetune_v1",
)
|
|
|
# Cached `evaluate` CER metric, created on first use.
_CER_METRIC = None


def _get_cer_metric():
    """Load the CER metric once and reuse it across evaluation rounds."""
    global _CER_METRIC
    if _CER_METRIC is None:
        _CER_METRIC = evaluate.load("cer")
    return _CER_METRIC


def compute_metrics(pred):
    """Compute the character error rate (CER) of generated text vs. references.

    Args:
        pred: the evaluation prediction passed by the trainer.
            ``predictions`` holds generated token ids (``predict_with_generate``)
            and ``label_ids`` holds the reference token ids. Both may contain -100.

    Returns:
        ``{"cer": float}`` computed on lower-cased decoded strings.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some models return a tuple whose first element holds the token ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored/padded positions (the Trainer also pads predictions
    # across batches with -100) and cannot be decoded. Swap in the pad token,
    # which skip_special_tokens then drops. Copy first so the caller's arrays
    # are not mutated.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = pad_token_id
    labels_ids = pred.label_ids.copy()
    labels_ids[labels_ids == -100] = pad_token_id

    pred_str = [text.lower() for text in processor.batch_decode(pred_ids, skip_special_tokens=True)]
    label_str = [text.lower() for text in processor.batch_decode(labels_ids, skip_special_tokens=True)]

    cer = _get_cer_metric().compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
|
|
|
|
|
# Seq2SeqTrainer so evaluation can run generation (predict_with_generate)
# and score the generated text with compute_metrics.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()