---
tags:
- image-classification
- pytorch
library_name: transformers
datasets:
- garythung/trashnet
---


# Trash Image Classification using Vision Transformer (ViT)

This repository contains an image classification model built by fine-tuning a pre-trained Vision Transformer (ViT) from Hugging Face. The model classifies images into six categories: cardboard, glass, metal, paper, plastic, and trash.

## Dataset

The model is trained on the [`garythung/trashnet`](https://huggingface.co/datasets/garythung/trashnet) dataset, which contains 5,054 images across six categories with the following distribution:

- Cardboard: 806 images
- Glass: 1002 images
- Metal: 820 images
- Paper: 1188 images
- Plastic: 964 images
- Trash: 274 images
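The class counts above are uneven: trash has roughly a quarter as many images as paper. The short sketch below, in plain Python, derives the total and each class's share from the listed counts. Nothing is read from the dataset itself; the numbers are copied from the list.

```python
# Per-class image counts, copied from the distribution listed in this README
class_counts = {
    "cardboard": 806,
    "glass": 1002,
    "metal": 820,
    "paper": 1188,
    "plastic": 964,
    "trash": 274,
}

total = sum(class_counts.values())

# Fraction of the dataset that each class represents
class_share = {name: count / total for name, count in class_counts.items()}

print(f"Total images: {total}")
# Print classes from most to least frequent
for name, share in sorted(class_share.items(), key=lambda item: item[1], reverse=True):
    print(f"{name:<10} {share:6.2%}")
```

Because trash makes up only about 5% of the images, overall accuracy may hide weaker performance on that class, so a per-class breakdown can be informative.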

## Model

We use the pre-trained Vision Transformer [`google/vit-base-patch16-224-in21k`](https://huggingface.co/google/vit-base-patch16-224-in21k) (pre-trained on ImageNet-21k) from Hugging Face as the backbone and fine-tune it on the TrashNet images described above.

The fine-tuned model is available on the Hugging Face Hub: [tribber93/my-trash-classification](https://huggingface.co/tribber93/my-trash-classification)
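Fine-tuning a checkpoint such as `google/vit-base-patch16-224-in21k` for six classes typically means attaching a new classification head and giving the model config a label mapping, which is what lets `model.config.id2label` return class names at inference time. The sketch below assumes the labels are in alphabetical order (the order listed in this README); the published model's `config.json` is the authority on the real mapping. `build_model` is a hypothetical helper, not part of this repository, and it downloads the checkpoint when called.

```python
# Assumed label order: alphabetical, as listed in this README.
# The published model's config.json is the authority on the real mapping.
LABELS = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

# Map integer class ids to names and back, as stored in the model config
id2label = {i: name for i, name in enumerate(LABELS)}
label2id = {name: i for i, name in id2label.items()}


def build_model(checkpoint: str = "google/vit-base-patch16-224-in21k"):
    """Hypothetical helper: load the ViT backbone with a fresh 6-way classification head.

    Downloads the checkpoint on first use. transformers is imported lazily so
    the label maps above can be used without it installed.
    """
    from transformers import AutoModelForImageClassification

    return AutoModelForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(LABELS),
        id2label=id2label,
        label2id=label2id,
    )
```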

## Usage

The following snippet downloads an example image and classifies it with the fine-tuned model:

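```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Download an example image and make sure it has three (RGB) channels
url = 'https://cdn.grid.id/crop/0x0:0x0/700x465/photo/grid/original/127308_kaleng-bekas.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")

# Load the fine-tuned model and its matching image processor from the Hub
model_name = "tribber93/my-trash-classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Preprocess the image and run a forward pass without tracking gradients
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# The predicted class is the index of the largest logit
predicted_id = outputs.logits.argmax(dim=-1).item()
print("Predicted class:", model.config.id2label[predicted_id])
```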
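The snippet above reports only the single most likely class. To see how confident the model is, the logits can be turned into probabilities with a softmax. Below is a minimal sketch in plain Python; `top_k_predictions` is a hypothetical helper, and the logits and label order in the example are made up for illustration, not real model output.

```python
import math


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one image.

    `logits` is a sequence of raw scores, e.g. outputs.logits[0].tolist(),
    and `id2label` maps class indices to names, e.g. model.config.id2label.
    """
    # Numerically stable softmax: subtract the max logit before exponentiating
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort class indices by probability, highest first, and keep the top k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Example with made-up logits and an assumed label order (not real model output)
dummy_labels = {0: "cardboard", 1: "glass", 2: "metal", 3: "paper", 4: "plastic", 5: "trash"}
dummy_logits = [0.1, 0.3, 4.2, 0.0, 1.1, -0.5]
for label, prob in top_k_predictions(dummy_logits, dummy_labels):
    print(f"{label}: {prob:.1%}")
```

With the real model, the helper would be called as `top_k_predictions(outputs.logits[0].tolist(), model.config.id2label)` after the forward pass.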

## Results

Training loss, validation loss, and accuracy after each of the three fine-tuning epochs:

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 3.3200        | 0.7011          | 86.25%   |
| 2     | 1.6611        | 0.4298          | 91.49%   |
| 3     | 1.4353        | 0.3563          | 94.26%   |
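Accuracy here is the fraction of evaluated images whose highest-scoring class matches the true label. The training script is not included in this repository, so the following is only a sketch of how such a metric is commonly computed with a `compute_metrics` function for the `transformers` `Trainer`. It assumes the function receives a `(logits, labels)` pair of NumPy arrays and is not necessarily the code used to produce the table.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy in the shape expected by transformers.Trainer(compute_metrics=...).

    Assumes `eval_pred` unpacks into (logits, labels), where logits has shape
    (num_examples, num_classes) and labels has shape (num_examples,).
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit for each example
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predictions == labels))}


# Tiny made-up example: 3 of the 4 predictions are correct
example_logits = np.array([
    [2.0, 0.1, 0.0],
    [0.2, 1.5, 0.3],
    [0.1, 0.2, 3.0],
    [1.0, 0.9, 0.8],
])
example_labels = np.array([0, 1, 2, 2])
print(compute_metrics((example_logits, example_labels)))
```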