# VaqAndOkvqa / app.py
#ํ—ˆ๊น…ํŽ˜์ด์Šค์—์„œ ๋Œ์•„๊ฐˆ ์ˆ˜ ์žˆ๋„๋ก ๋ฐ”๊พธ์–ด ๋ณด์•˜์Œ
import torch
from collections import defaultdict
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
# ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))
# Keep only the features we need (image, answers, question)
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})
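# An equivalent, more idiomatic way to keep only these columns (a sketch, not used here):
#   selected_dataset = dataset.remove_columns(
#       [c for c in dataset.column_names if c not in selected_features])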
# "Soft" encoding: map each answer string to an integer id
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': [answers_to_id[ans] for ans in ex['answers']],
    'question': ex['question'],
    'image': ex['image']
})

# Reverse lookup tables: id -> answer string, and row index -> encoded answer list
id_to_answers = {v: k for k, v in answers_to_id.items()}
id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}

# Replace each example's 'answers' with the id_to_labels entry keyed by its first answer id
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': id_to_labels.get(ex['answers'][0]),
    'question': ex['question'],
    'image': ex['image']
})
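# Quick sanity check (illustrative; on a fresh, uncached run the mapping might look
# like {'juggler': 0, 'clown': 1, ...}):
#   print(dict(list(answers_to_id.items())[:5]))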
flattened_features = []
for ex in selected_dataset:
    flattened_example = {
        'answers': ex['answers'],
        'question': ex['question'],
        'image': ex['image'],
    }
    flattened_features.append(flattened_example)
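# Note: flattened_features is only collected here for inspection; nothing below reads it.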
# ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model_name = 'microsoft/git-base-vqav2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
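# A generative alternative (sketch only, not wired into this script): the GIT VQA
# checkpoint is normally paired with its own processor and causal-LM head, roughly:
#   from transformers import AutoProcessor, AutoModelForCausalLM
#   processor = AutoProcessor.from_pretrained(model_name)
#   git_model = AutoModelForCausalLM.from_pretrained(model_name)
# `processor` and `git_model` are illustrative names and are not used below.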
# Train the model with the Trainer API
# (a generic multilingual BERT tokenizer; not necessarily the tokenizer bundled with
# the GIT checkpoint above)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        # Placeholder image inputs: fixed shape tuples and constant masks, not real tensors
        'pixel_values': [(4, 3, 244, 244)] * len(tokenized_inputs['input_ids']),
        'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
        'labels': [[label] for label in examples['answers']]
    }
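# The 'pixel_values' above are placeholders rather than real image tensors (GIT's
# image processor typically produces 3x224x224 tensors per image). A sketch of
# building real pixel_values with the checkpoint's processor (assumes AutoProcessor
# for microsoft/git-base-vqav2; not plugged into preprocess_function here):
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained(model_name)
#   pixel_values = processor(images=examples['image'], return_tensors='pt').pixel_values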
# Note: this reloads the raw split, so the encoded `selected_dataset` built above is
# not the data that actually reaches the Trainer.
dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=20,
    per_device_train_batch_size=4,
    logging_steps=500,
)
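# Note: by default the Trainer drops dataset columns that the model's forward() does
# not accept (remove_unused_columns=True), so 'pixel_values' and 'pixel_mask' would
# likely be discarded for a text-only classification model.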
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ok_vqa_dataset
)
# ๋ชจ๋ธ ํ•™์Šต
trainer.train()
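# A common follow-up (not part of the original script): persist the fine-tuned
# weights and tokenizer alongside the Trainer's output directory.
trainer.save_model('./results')
tokenizer.save_pretrained('./results')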