File size: 1,337 Bytes
61337ab 313d602 61337ab 313d602 61337ab 313d602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import io
import json

import torch
from PIL import Image
from torchvision import models, transforms
# Load model
class CustomResNet:
    """ResNet-152 image classifier restored from a saved state dict.

    The final fully connected layer is replaced by one with ``num_classes``
    outputs before the weights are loaded. The model runs in evaluation mode
    on CUDA when it is available and on the CPU otherwise.
    """

    def __init__(self, model_path, num_classes):
        """Build the network, load its weights and set up preprocessing.

        Args:
            model_path: Location of the saved state dict, handed to
                ``torch.load``.
            num_classes: Number of outputs of the replacement ``fc`` layer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favor of `weights=None`; left unchanged so older installs keep
        # working -- confirm against the pinned torchvision version.
        self.model = models.resnet152(pretrained=False)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)

        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should ever be passed in here.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_image_source(image_bytes):
        """Return an object ``Image.open`` can read from.

        Raw bytes-like input is wrapped in an in-memory stream; anything else
        (a path or an already opened file object) is returned unchanged.
        """
        if isinstance(image_bytes, (bytes, bytearray, memoryview)):
            return io.BytesIO(image_bytes)
        return image_bytes

    def predict(self, image_bytes):
        """Classify one image and return the index of the highest score.

        Args:
            image_bytes: The encoded image as raw bytes, or a path / file
                object accepted by ``Image.open``.

        Returns:
            The predicted class index as a Python ``int``.
        """
        # Load and preprocess image. `convert` returns a new image, so the
        # handle opened here can be released as soon as the block exits.
        with Image.open(self._to_image_source(image_bytes)) as raw:
            image = raw.convert("RGB")
        # unsqueeze(0) adds the leading batch dimension before the forward pass.
        tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Make prediction
        with torch.no_grad():
            outputs = self.model(tensor)
            _, predicted = torch.max(outputs, 1)
        return predicted.item()
# API function
def load_model(config_path="config.json", model_path="trained_model.pth"):
    """Create a ``CustomResNet`` from a JSON config and a saved weights file.

    Args:
        config_path: JSON file containing a ``"num_labels"`` entry, which is
            used as the number of output classes. Defaults to ``config.json``
            in the current working directory.
        model_path: Saved state dict passed on to ``CustomResNet``. Defaults
            to ``trained_model.pth`` in the current working directory.

    Returns:
        The constructed ``CustomResNet`` instance.

    Raises:
        KeyError: If the config has no ``"num_labels"`` entry.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return CustomResNet(model_path, config["num_labels"])
|