from __future__ import print_function, division import torch import torch.nn as nn import torch.optim as optim import torchvision from torchvision import datasets, models, transforms import os from torch.utils.data import DataLoader import numpy as np from tqdm import tqdm from timeit import default_timer as timer from datasets import load_dataset from datasets import load_metric import evaluate from train import * from transformers import AutoImageProcessor, MobileViTV2ForImageClassification, MobileViTV2Config def collate_fn(batch): reviews, labels = zip(*batch) labels = torch.Tensor(labels) return reviews, labels def compute_metrics(eval_pred): metric = load_metric("accuracy") predictions = np.argmax(eval_pred.predictions, axis=1) return metric.compute(predictions=predictions, references=eval_pred.label_ids) if __name__ == '__main__': # Define path # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cpu") testdir = f"data/grocery_data_small/test" # Change to fit hardware num_workers = 0 batch_size = 8 num_epochs = 5 learning_rate = 0.001 max_epochs_stop = 3 print_every = 1 num_classes = 8 # define image transformations # define transforms # Datasets from folders data = datasets.ImageFolder(root=testdir, transform=transforms.ToTensor()) # Dataloader iterators, make sure to shuffle dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) # Define the network with pretrained models model_checkpoint = "apple/mobilevitv2-1.0-imagenet1k-256" processor = AutoImageProcessor.from_pretrained(model_checkpoint) model = MobileViTV2ForImageClassification.from_pretrained(model_checkpoint, num_labels=num_classes, ignore_mismatched_sizes=True) #model = MobileViTV2ForImageClassification(MobileViTV2Config()) #model.classifier = nn.Linear(model.classifier.in_features, num_classes) print(model.classifier) model.load_state_dict(torch.load('best_model_mb_ft.pth')) model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in tqdm(dataloader, desc="predicting"): inputs = processor(images=inputs, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits predicted = logits.argmax(-1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total * 100 print(f"Predict Accuracy: {accuracy:.2f}%")