# flowers_custom / model.py
# megiddo's picture
# Update model.py
# 313d602 verified
import io
import json

import torch
from PIL import Image
from torchvision import models, transforms
# Load model
class CustomResNet:
    """ResNet-152 image classifier restored from a saved state dict.

    The network uses torchvision's ``resnet152`` architecture with its final
    fully connected layer replaced by a ``num_classes``-way linear layer. The
    weights stored at ``model_path`` are loaded into it, the model is moved to
    CUDA when available (CPU otherwise) and switched to evaluation mode.
    """

    def __init__(self, model_path, num_classes):
        """Build the network, load its weights and set up preprocessing.

        Args:
            model_path: Checkpoint location handed to ``torch.load``; it must
                hold a state dict matching the modified ResNet-152.
            num_classes: Number of outputs of the replacement ``fc`` layer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=None``; the old keyword is kept here so the
        # module still works with older torchvision versions.
        self.model = models.resnet152(pretrained=False)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)

        # NOTE(security): ``torch.load`` is pickle-based and can execute
        # arbitrary code; only load checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing: fixed 224x224 resize, tensor conversion, then
        # per-channel normalisation with the constants below.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _as_image_source(image_bytes):
        """Return an object ``Image.open`` can read from.

        Raw bytes-like data is wrapped in an in-memory stream, because Pillow
        interprets a ``bytes`` argument as a filesystem path rather than as
        image content. Anything else (path string, ``os.PathLike``, open
        binary file object) is returned unchanged.
        """
        if isinstance(image_bytes, (bytes, bytearray, memoryview)):
            return io.BytesIO(image_bytes)
        return image_bytes

    def predict(self, image_bytes):
        """Classify one image and return the index of the top-scoring class.

        Args:
            image_bytes: Encoded image content as ``bytes``/``bytearray``/
                ``memoryview``, or a path or binary file-like object accepted
                by ``PIL.Image.open``.

        Returns:
            int: Index of the largest model output along dimension 1.
        """
        # Load and preprocess the image. ``convert`` yields an independent
        # in-memory copy, so the underlying file handle can be released
        # immediately instead of being leaked.
        with Image.open(self._as_image_source(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")
        tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Make the prediction without tracking gradients.
        with torch.no_grad():
            outputs = self.model(tensor)
            _, predicted = torch.max(outputs, 1)
        return predicted.item()
# API function
def load_model(model_path="trained_model.pth", config_path="config.json"):
    """Create a ``CustomResNet`` from a checkpoint and a JSON config file.

    Called without arguments, this reads ``config.json`` and
    ``trained_model.pth`` relative to the current working directory.

    Args:
        model_path: Location of the saved weights passed to ``CustomResNet``.
        config_path: Location of a JSON file containing a ``"num_labels"``
            entry, which is used as the number of output classes.

    Returns:
        CustomResNet: The constructed model wrapper.

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If the config has no ``"num_labels"`` entry.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return CustomResNet(model_path, config["num_labels"])