yuankaihuo commited on
Commit
b4750f5
·
1 Parent(s): 2c6b804

Upload 2 files

Browse files
Files changed (2) hide show
  1. best_model_mb_ft.pth +3 -0
  2. grocery_predict.py +94 -0
best_model_mb_ft.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4aee2952b1e7bda1673648f7917b0743230bd72a61e5ce747fb9c1c88f0300d
3
+ size 17752814
grocery_predict.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, division
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torchvision
7
+ from torchvision import datasets, models, transforms
8
+ import os
9
+ from torch.utils.data import DataLoader
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from timeit import default_timer as timer
13
+ from datasets import load_dataset
14
+ from datasets import load_metric
15
+ import evaluate
16
+ from train import *
17
+ from transformers import AutoImageProcessor, MobileViTV2ForImageClassification, MobileViTV2Config
18
+
19
+
20
+ def collate_fn(batch):
21
+ reviews, labels = zip(*batch)
22
+ labels = torch.Tensor(labels)
23
+ return reviews, labels
24
+
25
+
26
+ def compute_metrics(eval_pred):
27
+ metric = load_metric("accuracy")
28
+ predictions = np.argmax(eval_pred.predictions, axis=1)
29
+ return metric.compute(predictions=predictions, references=eval_pred.label_ids)
30
+
31
+
32
+
33
+
34
+ if __name__ == '__main__':
35
+ # Define path
36
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ device = torch.device("cpu")
38
+
39
+ testdir = f"data/grocery_data_small/test"
40
+ # Change to fit hardware
41
+ num_workers = 0
42
+ batch_size = 8
43
+ num_epochs = 5
44
+ learning_rate = 0.001
45
+ max_epochs_stop = 3
46
+ print_every = 1
47
+ num_classes = 8
48
+ # define image transformations
49
+ # define transforms
50
+
51
+ # Datasets from folders
52
+
53
+ data = datasets.ImageFolder(root=testdir, transform=transforms.ToTensor())
54
+
55
+
56
+ # Dataloader iterators, make sure to shuffle
57
+
58
+ dataloader = DataLoader(data, batch_size=batch_size, shuffle=True,
59
+ num_workers=num_workers, collate_fn=collate_fn)
60
+
61
+
62
+ # Define the network with pretrained models
63
+
64
+ model_checkpoint = "apple/mobilevitv2-1.0-imagenet1k-256"
65
+
66
+ processor = AutoImageProcessor.from_pretrained(model_checkpoint)
67
+ model = MobileViTV2ForImageClassification.from_pretrained(model_checkpoint, num_labels=num_classes,
68
+ ignore_mismatched_sizes=True)
69
+
70
+ #model = MobileViTV2ForImageClassification(MobileViTV2Config())
71
+ #model.classifier = nn.Linear(model.classifier.in_features, num_classes)
72
+ print(model.classifier)
73
+
74
+ model.load_state_dict(torch.load('best_model_mb_ft.pth'))
75
+ model.eval()
76
+ correct = 0
77
+ total = 0
78
+
79
+ with torch.no_grad():
80
+ for inputs, labels in tqdm(dataloader, desc="predicting"):
81
+ inputs = processor(images=inputs, return_tensors="pt")
82
+
83
+ outputs = model(**inputs)
84
+ logits = outputs.logits
85
+
86
+ predicted = logits.argmax(-1)
87
+
88
+ total += labels.size(0)
89
+ correct += (predicted == labels).sum().item()
90
+
91
+ accuracy = correct / total * 100
92
+
93
+ print(f"Predict Accuracy: {accuracy:.2f}%")
94
+