File size: 2,725 Bytes
b4750f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import os
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from timeit import default_timer as timer
from datasets import load_dataset
from datasets import load_metric
import evaluate
from train import *
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification, MobileViTV2Config


def collate_fn(batch):
    """Collate ``(image, label)`` pairs into an image tuple and a label tensor.

    Images are left un-stacked, so samples of different sizes can share a
    batch; they are passed on unchanged to whatever consumes the batch.

    Args:
        batch: sequence of ``(image, label)`` pairs, e.g. as yielded by
            ``torchvision.datasets.ImageFolder``.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a tuple of the
        original images and ``labels`` is a 1-D ``torch.long`` tensor of
        class indices.
    """
    if not batch:
        # zip(*[]) yields nothing to unpack; return an empty batch instead.
        return (), torch.empty(0, dtype=torch.long)
    images, labels = zip(*batch)
    # Class indices are integers: keep them int64 so they compare/index like
    # an argmax output (torch.Tensor(...) would silently produce float32).
    labels = torch.tensor(labels, dtype=torch.long)
    return images, labels


def compute_metrics(eval_pred):
    """Compute top-1 accuracy from an evaluation prediction bundle.

    Args:
        eval_pred: object exposing ``predictions`` (2-D per-class scores,
            one row per sample) and ``label_ids`` (true class indices),
            e.g. a ``transformers.EvalPrediction``.

    Returns:
        ``{"accuracy": float}``: the fraction of samples whose argmax
        prediction equals the label (0.0 when there are no samples).
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = np.asarray(eval_pred.label_ids)
    if predictions.size == 0:
        # Avoid a NaN mean (and RuntimeWarning) on an empty evaluation set.
        return {"accuracy": 0.0}
    # Computed directly rather than via datasets.load_metric("accuracy"),
    # which is deprecated (removed in datasets 3.x) and re-loaded the metric
    # on every call; this is the same normalized accuracy.
    return {"accuracy": float(np.mean(predictions == references))}




if __name__ == '__main__':
    # Evaluate the fine-tuned MobileViTV2 checkpoint on the test folder and
    # print top-1 accuracy.

    # Use torch.device("cuda" if torch.cuda.is_available() else "cpu") to run
    # on GPU; the model, inputs and labels are all moved to `device` below.
    device = torch.device("cpu")

    testdir = "data/grocery_data_small/test"
    # Change to fit hardware
    num_workers = 0
    batch_size = 8
    num_classes = 8

    # ImageFolder: one subdirectory per class under `testdir`.
    # NOTE(review): ToTensor() already scales pixels to [0, 1], and the image
    # processor below likely applies its own default rescale on top, scaling
    # the pixels twice. Left unchanged because the checkpoint may have been
    # fine-tuned with this same pipeline (train.py); confirm before changing.
    data = datasets.ImageFolder(root=testdir, transform=transforms.ToTensor())

    # Sample order does not affect accuracy, so keep evaluation deterministic.
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, collate_fn=collate_fn)

    model_checkpoint = "apple/mobilevitv2-1.0-imagenet1k-256"

    processor = AutoImageProcessor.from_pretrained(model_checkpoint)
    # ignore_mismatched_sizes allows the pretrained head to be replaced by a
    # num_classes-way classifier; all weights are then overwritten from the
    # fine-tuned state dict loaded below.
    model = MobileViTV2ForImageClassification.from_pretrained(
        model_checkpoint, num_labels=num_classes, ignore_mismatched_sizes=True)
    print(model.classifier)

    # map_location lets a checkpoint saved on another device (e.g. GPU) load
    # onto `device`.
    state_dict = torch.load('best_model_mb_ft.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="predicting"):
            inputs = processor(images=images, return_tensors="pt").to(device)
            logits = model(**inputs).logits
            predicted = logits.argmax(-1)

            # Integer labels on the same device as the predictions.
            labels = labels.to(device).long()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader rather than dividing by zero.
    accuracy = correct / total * 100 if total else 0.0

    print(f"Predict Accuracy: {accuracy:.2f}%")