---
library_name: transformers
base_model:
- google/vit-base-patch16-224
---

# Model Card for Pokémon Type Classification |
|
|
|
This model leverages a Vision Transformer (ViT) to classify Pokémon images into 18 different types. |
|
|
|
It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset. |
|
|
|
## Model Details

- **Developer:** Xianglu (Steven) Zhu
- **Purpose:** Pokémon type classification
- **Model Type:** Vision Transformer (ViT) for image classification

## Getting Started

Here’s how you can use the model for classification:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the pretrained model and image processor
# (ViTImageProcessor replaces the deprecated ViTFeatureExtractor)
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_image_processor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")

# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=hf_image_processor.image_mean,
        std=hf_image_processor.image_std,
    ),
])

# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}

# Load and preprocess the image
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Make a prediction (eval mode, no gradient tracking)
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()

predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
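
The example above reports only the single most likely type. If a ranked list with probabilities is more useful, the logits can be passed through a softmax and sorted. The following is a minimal sketch, not part of the original project: the helper names `softmax` and `top_k_types` are hypothetical, the label order is copied from `labels_dict`, and the logits in the demo are made up for illustration.

```python
import math

# Label order copied from labels_dict in the example above
TYPE_NAMES = [
    "Grass", "Fire", "Water", "Bug", "Normal", "Poison", "Electric",
    "Ground", "Fairy", "Fighting", "Psychic", "Rock", "Ghost",
    "Ice", "Dragon", "Dark", "Steel", "Flying",
]


def softmax(logits):
    """Convert raw logits to probabilities (subtracting the max for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_types(logits, k=3):
    """Return the k most probable (type, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(TYPE_NAMES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    # Made-up logits for illustration only, not real model output
    fake_logits = [0.0] * 18
    fake_logits[6] = 4.0  # Electric
    fake_logits[4] = 2.0  # Normal
    for name, prob in top_k_types(fake_logits, k=3):
        print(f"{name}: {prob:.3f}")
```

With the inference code above, the same helper could be called as `top_k_types(logits[0].tolist(), k=3)`, since `logits` there has shape `(1, 18)`.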