Spaces:
Sleeping
Sleeping
from __future__ import print_function, division | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
from torchvision import datasets, models, transforms | |
import os | |
from torch.utils.data import DataLoader | |
import numpy as np | |
from tqdm import tqdm | |
from timeit import default_timer as timer | |
from datasets import load_dataset | |
from datasets import load_metric | |
import evaluate | |
from train import * | |
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification, MobileViTV2Config | |
def collate_fn(batch):
    """Collate ``(sample, label)`` pairs into a batch without stacking samples.

    The samples are returned as a plain tuple rather than a stacked tensor,
    so they may differ in size from one another.

    Args:
        batch: Non-empty sequence of ``(sample, label)`` pairs whose labels
            are integer class indices.

    Returns:
        A ``(samples, labels)`` pair: ``samples`` is a tuple holding the
        samples in batch order and ``labels`` is a 1-D ``torch.long`` tensor
        with one entry per sample.
    """
    samples, labels = zip(*batch)
    # Class indices are integers: build an int64 tensor explicitly instead of
    # going through the legacy torch.Tensor(...) constructor, which casts the
    # labels to float32.
    labels = torch.tensor(labels, dtype=torch.long)
    return samples, labels
def compute_metrics(eval_pred):
    """Compute the accuracy of a set of evaluation predictions.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores, reduced
            with an argmax over axis 1) and ``label_ids`` (reference labels).

    Returns:
        Whatever the accuracy metric's ``compute`` returns for the predicted
        class indices against ``eval_pred.label_ids``.
    """
    # ``datasets.load_metric`` is deprecated and no longer exists in
    # datasets>=3.0; the ``evaluate`` package (already imported by this file)
    # is its replacement.
    metric = evaluate.load("accuracy")
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
def main():
    """Evaluate the fine-tuned MobileViTV2 checkpoint on the test image folder.

    Reads the images under ``data/grocery_data_small/test``, restores the
    weights stored in ``best_model_mb_ft.pth`` and prints the top-1 accuracy
    over the whole folder as a percentage.
    """
    # Evaluation is pinned to the CPU. The model, inputs and labels are all
    # moved to ``device`` below, so this is the only line to change in order
    # to run elsewhere.
    device = torch.device("cpu")
    testdir = "data/grocery_data_small/test"

    # Change to fit hardware
    num_workers = 0
    batch_size = 8
    num_classes = 8

    # Datasets from folders.
    # NOTE(review): ToTensor() already scales pixels to [0, 1] and the image
    # processor may rescale again by default -- confirm this matches the
    # preprocessing used when best_model_mb_ft.pth was trained before
    # changing it.
    data = datasets.ImageFolder(root=testdir, transform=transforms.ToTensor())

    # Accuracy is a sum over all samples, so the order is irrelevant and the
    # evaluation loader does not need to shuffle.
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, collate_fn=collate_fn)

    # Define the network from the pretrained checkpoint; mismatched sizes are
    # ignored so the classification head can be built with ``num_classes``
    # outputs.
    model_checkpoint = "apple/mobilevitv2-1.0-imagenet1k-256"
    processor = AutoImageProcessor.from_pretrained(model_checkpoint)
    model = MobileViTV2ForImageClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    print(model.classifier)

    # map_location lets a checkpoint that was saved on a GPU be restored on
    # the device used here.
    state_dict = torch.load('best_model_mb_ft.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="predicting"):
            inputs = processor(images=inputs, return_tensors="pt").to(device)
            labels = labels.to(device)
            logits = model(**inputs).logits
            predicted = logits.argmax(-1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total * 100
    print(f"Predict Accuracy: {accuracy:.2f}%")


if __name__ == '__main__':
    main()