Spaces:
Paused
A newer version of the Gradio SDK is available:
5.23.3
์ด๋ฏธ์ง ์บก์ ๋[[image-captioning]]
[[open-in-colab]]
์ด๋ฏธ์ง ์บก์ ๋(Image captioning)์ ์ฃผ์ด์ง ์ด๋ฏธ์ง์ ๋ํ ์บก์ ์ ์์ธกํ๋ ์์ ์ ๋๋ค. ์ด๋ฏธ์ง ์บก์ ๋์ ์๊ฐ ์ฅ์ ์ธ์ด ๋ค์ํ ์ํฉ์ ํ์ํ๋ ๋ฐ ๋์์ ์ค ์ ์๋๋ก ์๊ฐ ์ฅ์ ์ธ์ ๋ณด์กฐํ๋ ๋ฑ ์ค์ํ์์ ํํ ํ์ฉ๋ฉ๋๋ค. ๋ฐ๋ผ์ ์ด๋ฏธ์ง ์บก์ ๋์ ์ด๋ฏธ์ง๋ฅผ ์ค๋ช ํจ์ผ๋ก์จ ์ฌ๋๋ค์ ์ฝํ ์ธ ์ ๊ทผ์ฑ์ ๊ฐ์ ํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค.
์ด ๊ฐ์ด๋์์๋ ์๊ฐํ ๋ด์ฉ์ ์๋์ ๊ฐ์ต๋๋ค:
- ์ด๋ฏธ์ง ์บก์ ๋ ๋ชจ๋ธ์ ํ์ธํ๋ํฉ๋๋ค.
- ํ์ธํ๋๋ ๋ชจ๋ธ์ ์ถ๋ก ์ ์ฌ์ฉํฉ๋๋ค.
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
pip install transformers datasets evaluate -q
pip install jiwer -q
Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ๋ฉด ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ ์ ์์ต๋๋ค. ํ ํฐ์ ์ ๋ ฅํ์ฌ ๋ก๊ทธ์ธํ์ธ์.
from huggingface_hub import notebook_login
notebook_login()
ํฌ์ผ๋ชฌ BLIP ์บก์ ๋ฐ์ดํฐ์ธํธ ๊ฐ์ ธ์ค๊ธฐ[[load-the-pokmon-blip-captions-dataset]]
{์ด๋ฏธ์ง-์บก์ } ์์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ์ธํธ๋ฅผ ๊ฐ์ ธ์ค๋ ค๋ฉด ๐ค Dataset ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. PyTorch์์ ์์ ๋ง์ ์ด๋ฏธ์ง ์บก์ ๋ฐ์ดํฐ์ธํธ๋ฅผ ๋ง๋ค๋ ค๋ฉด ์ด ๋ ธํธ๋ถ์ ์ฐธ์กฐํ์ธ์.
from datasets import load_dataset
ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
DatasetDict({
train: Dataset({
features: ['image', 'text'],
num_rows: 833
})
})
์ด ๋ฐ์ดํฐ์ธํธ๋ image
์ text
๋ผ๋ ๋ ํน์ฑ์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
๋ง์ ์ด๋ฏธ์ง ์บก์ ๋ฐ์ดํฐ์ธํธ์๋ ์ด๋ฏธ์ง๋น ์ฌ๋ฌ ๊ฐ์ ์บก์ ์ด ํฌํจ๋์ด ์์ต๋๋ค. ์ด๋ฌํ ๊ฒฝ์ฐ, ์ผ๋ฐ์ ์ผ๋ก ํ์ต ์ค์ ์ฌ์ฉ ๊ฐ๋ฅํ ์บก์ ์ค์์ ๋ฌด์์๋ก ์ํ์ ์ถ์ถํฉ๋๋ค.
[~datasets.Dataset.train_test_split] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ์ธํธ์ ํ์ต ๋ถํ ์ ํ์ต ๋ฐ ํ ์คํธ ์ธํธ๋ก ๋๋๋๋ค:
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]
ํ์ต ์ธํธ์ ์ํ ๋ช ๊ฐ๋ฅผ ์๊ฐํํด ๋ด ์๋ค. Let's visualize a couple of samples from the training set.
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np
def plot_images(images, captions):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
caption = captions[i]
caption = "\n".join(wrap(caption, 12))
plt.title(caption)
plt.imshow(images[i])
plt.axis("off")
sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)

๋ฐ์ดํฐ์ธํธ ์ ์ฒ๋ฆฌ[[preprocess-the-dataset]]
๋ฐ์ดํฐ์ธํธ์๋ ์ด๋ฏธ์ง์ ํ ์คํธ๋ผ๋ ๋ ๊ฐ์ง ์์์ด ์๊ธฐ ๋๋ฌธ์, ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ์์ ์ด๋ฏธ์ง์ ์บก์ ์ ๋ชจ๋ ์ ์ฒ๋ฆฌํฉ๋๋ค.
์ ์ฒ๋ฆฌ ์์ ์ ์ํด, ํ์ธํ๋ํ๋ ค๋ ๋ชจ๋ธ์ ์ฐ๊ฒฐ๋ ํ๋ก์ธ์ ํด๋์ค๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
from transformers import AutoProcessor
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
ํ๋ก์ธ์๋ ๋ด๋ถ์ ์ผ๋ก ํฌ๊ธฐ ์กฐ์ ๋ฐ ํฝ์ ํฌ๊ธฐ ์กฐ์ ์ ํฌํจํ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ๋ฅผ ์ํํ๊ณ ์บก์ ์ ํ ํฐํํฉ๋๋ค.
def transforms(example_batch):
images = [x for x in example_batch["image"]]
captions = [x for x in example_batch["text"]]
inputs = processor(images=images, text=captions, padding="max_length")
inputs.update({"labels": inputs["input_ids"]})
return inputs
train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
๋ฐ์ดํฐ์ธํธ๊ฐ ์ค๋น๋์์ผ๋ ์ด์ ํ์ธํ๋์ ์ํด ๋ชจ๋ธ์ ์ค์ ํ ์ ์์ต๋๋ค.
๊ธฐ๋ณธ ๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ[[load-a-base-model]]
"microsoft/git-base"๋ฅผ AutoModelForCausalLM
๊ฐ์ฒด๋ก ๊ฐ์ ธ์ต๋๋ค.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(checkpoint)
ํ๊ฐ[[evaluate]]
์ด๋ฏธ์ง ์บก์ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก Rouge ์ ์ ๋๋ ๋จ์ด ์ค๋ฅ์จ(Word Error Rate)๋ก ํ๊ฐํฉ๋๋ค. ์ด ๊ฐ์ด๋์์๋ ๋จ์ด ์ค๋ฅ์จ(WER)์ ์ฌ์ฉํฉ๋๋ค.
์ด๋ฅผ ์ํด ๐ค Evaluate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. WER์ ์ ์ฌ์ ์ ํ ์ฌํญ ๋ฐ ๊ธฐํ ๋ฌธ์ ์ ์ ์ด ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์.
from evaluate import load
import torch
wer = load("wer")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predicted = logits.argmax(-1)
decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
return {"wer_score": wer_score}
ํ์ต![[train!]]
์ด์ ๋ชจ๋ธ ํ์ธํ๋์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค. ์ด๋ฅผ ์ํด ๐ค [Trainer
]๋ฅผ ์ฌ์ฉํฉ๋๋ค.
๋จผ์ , [TrainingArguments
]๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต ์ธ์๋ฅผ ์ ์ํฉ๋๋ค.
from transformers import TrainingArguments, Trainer
model_name = checkpoint.split("/")[1]
training_args = TrainingArguments(
output_dir=f"{model_name}-pokemon",
learning_rate=5e-5,
num_train_epochs=50,
fp16=True,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
gradient_accumulation_steps=2,
save_total_limit=3,
evaluation_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
logging_steps=50,
remove_unused_columns=False,
push_to_hub=True,
label_names=["labels"],
load_best_model_at_end=True,
)
ํ์ต ์ธ์๋ฅผ ๋ฐ์ดํฐ์ธํธ, ๋ชจ๋ธ๊ณผ ํจ๊ป ๐ค Trainer์ ์ ๋ฌํฉ๋๋ค.
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=test_ds,
compute_metrics=compute_metrics,
)
ํ์ต์ ์์ํ๋ ค๋ฉด [Trainer
] ๊ฐ์ฒด์์ [~Trainer.train
]์ ํธ์ถํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
trainer.train()
ํ์ต์ด ์งํ๋๋ฉด์ ํ์ต ์์ค์ด ์ํํ๊ฒ ๊ฐ์ํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
ํ์ต์ด ์๋ฃ๋๋ฉด ๋ชจ๋ ์ฌ๋์ด ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์๋๋ก [~Trainer.push_to_hub
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ธ์ ๊ณต์ ํ์ธ์:
trainer.push_to_hub()
์ถ๋ก [[inference]]
test_ds
์์ ์ํ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ ธ์ ๋ชจ๋ธ์ ํ
์คํธํฉ๋๋ค.
from PIL import Image
import requests
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
image

๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง๋ฅผ ์ค๋นํฉ๋๋ค.
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
[generate
]๋ฅผ ํธ์ถํ๊ณ ์์ธก์ ๋์ฝ๋ฉํฉ๋๋ค.
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
a drawing of a pink and blue pokemon
ํ์ธํ๋๋ ๋ชจ๋ธ์ด ๊ฝค ๊ด์ฐฎ์ ์บก์ ์ ์์ฑํ ๊ฒ ๊ฐ์ต๋๋ค!