---
tags:
- image-classification
- pytorch
library_name: transformers
datasets:
- garythung/trashnet
---
# Trash Image Classification using Vision Transformer (ViT)
This repository contains an image classification model built on a pre-trained Vision Transformer (ViT) from Hugging Face, fine-tuned to classify images into six categories: cardboard, glass, metal, paper, plastic, and trash.
## Dataset
The dataset is [`garythung/trashnet`](https://huggingface.co/datasets/garythung/trashnet), which contains images in six categories with the following distribution:
- Cardboard: 806 images
- Glass: 1002 images
- Metal: 820 images
- Paper: 1188 images
- Plastic: 964 images
- Trash: 274 images
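The dataset can be pulled directly from the Hub with the `datasets` library. Below is a minimal loading sketch, assuming the dataset exposes a single `train` split with `image` and `label` columns (worth verifying against the dataset card):

```python
from datasets import load_dataset

# Load TrashNet from the Hugging Face Hub; the split name ("train") and
# column names ("image", "label") are assumptions based on common Hub
# image datasets.
dataset = load_dataset("garythung/trashnet", split="train")
print(dataset.features)          # inspect the label names and image column
print(dataset[0]["image"].size)  # each example holds a PIL image
```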
## Model
We use the pre-trained Vision Transformer [`google/vit-base-patch16-224-in21k`](https://huggingface.co/google/vit-base-patch16-224-in21k) from Hugging Face and fine-tune it on the dataset above for the six-way trash classification task.
The trained model is accessible on Hugging Face Hub at: [tribber93/my-trash-classification](https://huggingface.co/tribber93/my-trash-classification)
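For reference, the sketch below shows one plausible way to reproduce the fine-tuning with the `transformers` `Trainer` API. The 80/20 split, batch size, and learning rate are illustrative assumptions, not the documented training recipe:

```python
import torch
from datasets import load_dataset
from transformers import (AutoImageProcessor, AutoModelForImageClassification,
                          Trainer, TrainingArguments)

# Illustrative setup only; split ratio and hyperparameters are assumed.
ds = load_dataset("garythung/trashnet", split="train").train_test_split(test_size=0.2)
labels = ds["train"].features["label"].names

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

def transform(batch):
    # Resize and normalize PIL images into the pixel-value tensors ViT expects
    inputs = processor([img.convert("RGB") for img in batch["image"]],
                       return_tensors="pt")
    inputs["labels"] = batch["label"]
    return inputs

prepared = ds.with_transform(transform)

def collate_fn(examples):
    # Stack per-example tensors into a training batch
    return {
        "pixel_values": torch.stack([e["pixel_values"] for e in examples]),
        "labels": torch.tensor([e["labels"] for e in examples]),
    }

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={l: i for i, l in enumerate(labels)},
)

args = TrainingArguments(
    output_dir="vit-trash",
    per_device_train_batch_size=16,  # assumed
    learning_rate=5e-5,              # assumed
    num_train_epochs=3,              # matches the results table below
    remove_unused_columns=False,     # keep the raw "image" column for the transform
)

trainer = Trainer(model=model, args=args, data_collator=collate_fn,
                  train_dataset=prepared["train"], eval_dataset=prepared["test"])
trainer.train()
```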
## Usage
To run inference with the model:
```python
import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Example image: a used can
url = 'https://cdn.grid.id/crop/0x0:0x0/700x465/photo/grid/original/127308_kaleng-bekas.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# Load the fine-tuned model and its matching image processor from the Hub
model_name = "tribber93/my-trash-classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Preprocess the image and run inference without tracking gradients
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Map the highest-scoring logit back to its class name
predictions = torch.argmax(outputs.logits, dim=-1)
print("Predicted class:", model.config.id2label[predictions.item()])
```
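To see how confident the model is rather than just the top label, the logits from the example above can be turned into per-class probabilities with a softmax:

```python
import torch.nn.functional as F

# Per-class probabilities from the logits computed in the example above,
# printed from most to least likely
probs = F.softmax(outputs.logits, dim=-1)[0]
for idx, p in sorted(enumerate(probs.tolist()), key=lambda x: -x[1]):
    print(f"{model.config.id2label[idx]}: {p:.2%}")
```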
## Results
After three epochs of fine-tuning, the model achieved the following performance:
| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1 | 3.3200 | 0.7011 | 86.25% |
| 2 | 1.6611 | 0.4298 | 91.49% |
| 3 | 1.4353 | 0.3563 | 94.26% |