|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
from PIL import Image |
|
import json |
|
|
|
|
|
# Build ResNet-50 initialised with torchvision's default pretrained weights.
# These remain in use if the custom checkpoint below cannot be loaded.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()

# Optional custom checkpoint; expected to be a plain state_dict for ResNet-50.
model_path = 'path_to_your_model_file.pth'

try:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # Load into a fresh, weight-less model rather than into ``model`` directly:
    # load_state_dict copies every matching tensor before raising on a
    # mismatch, so a failed load would otherwise leave ``model`` holding a
    # mix of pretrained and checkpoint weights.
    candidate = models.resnet50(weights=None)
    candidate.load_state_dict(state_dict)
except FileNotFoundError:
    print(f"Model file not found: {model_path}; using default pretrained weights.")
except RuntimeError as e:
    print("Error loading state_dict:", e)
    print("Ensure that the saved model architecture matches ResNet50.")
else:
    # Checkpoint loaded cleanly: switch to it (in inference mode).
    candidate.eval()
    model = candidate
|
|
|
|
|
# Evaluation-time image pipeline: scale the shorter side to 256 px, take the
# central 224x224 patch, convert to a float tensor, then normalise each channel
# with the given mean/std.
# NOTE(review): confirm these sizes/statistics match how the loaded checkpoint
# was trained; torchvision's bundled transforms for some ResNet-50 weight
# versions use a different resize size.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
|
|
|
|
|
# Class-label table consulted by predict() using the predicted class index.
# JSON text is UTF-8 by specification, so decode explicitly instead of relying
# on the platform's default encoding (which differs e.g. on Windows).
with open("imagenet_classes.json", encoding="utf-8") as f:
    labels = json.load(f)
|
|
|
|
|
def predict(image_path):
    """Classify a single image and return its label.

    Relies on the module-level ``model``, ``preprocess`` and ``labels``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (typically a path).

    Returns:
        The ``labels`` entry for the highest-scoring class index. ``labels``
        may be a dict keyed by the stringified class index, or a list ordered
        by class index.
    """
    # Use a context manager so the underlying file handle is released;
    # convert("RGB") returns a new, fully loaded image, so it stays usable
    # after the file is closed.
    with Image.open(image_path) as img:
        input_image = img.convert("RGB")

    # preprocess yields one image tensor; unsqueeze(0) adds the batch
    # dimension the model expects.
    input_batch = preprocess(input_image).unsqueeze(0)

    if torch.cuda.is_available():
        # Model and input must live on the same device.
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)

    # Index of the highest score for the (only) image in the batch.
    predicted_idx = int(output.argmax(dim=1).item())

    # Support both common label-file layouts: {"0": ..., "1": ...} or [...].
    if isinstance(labels, dict):
        return labels[str(predicted_idx)]
    return labels[predicted_idx]
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo prediction only when executed as a script, so importing
    # this module (e.g. to reuse predict()) does not open a placeholder path.
    image_path = 'path_to_your_image.jpg'
    predicted_class = predict(image_path)
    print(f"Predicted class: {predicted_class}")
|
|