# File-page header: author GabrielML; commit 234009d ("Init gradio repo"); raw / history / blame; 2.99 kB.
from torchvision import transforms
from torchvision import models
from efficientnet_pytorch import EfficientNet
import torch
from CustomModels import DinoVisionClassifier
# Output index -> label for the five waste categories every model predicts.
classes = dict(enumerate(["Glas", "Organic", "Papier", "Restmüll", "Wertstoff"]))

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Standard ImageNet per-channel mean / std, shared by both pipelines below.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing for the CNN models: resize straight to 256x256 with bicubic
# resampling (aspect ratio is not preserved), then tensorize and normalize.
transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Preprocessing for the DINOv2 classifier: resize the shorter side to 256,
# take the central 224x224 crop, then tensorize and normalize.
transform_dinov2 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
def _build_efficientnet(variant):
    """Build an untrained EfficientNet ``variant`` with a ``len(classes)``-way head."""
    # ``from_name`` builds only the architecture. The strict ``load_state_dict``
    # in ``load_specific_model`` overwrites every parameter, so fetching ImageNet
    # weights via ``from_pretrained`` would just be a wasted download.
    return EfficientNet.from_name(variant, num_classes=len(classes))


def _build_vgg19():
    """Build an untrained VGG-19 whose classifier is a single linear layer."""
    net = models.vgg19()
    # The checkpoint stores one Linear layer as ``classifier`` (the original
    # three-layer stack is replaced), so the layout must match before loading.
    net.classifier = torch.nn.Linear(net.classifier[0].in_features, len(classes))
    return net


def _build_resnet50():
    """Build an untrained ResNet-50 with a ``len(classes)``-way ``fc`` head."""
    net = models.resnet50()
    net.fc = torch.nn.Linear(net.fc.in_features, len(classes))
    return net


def _build_dinov2_vits14():
    """Fetch the DINOv2 ViT-S/14 backbone via torch.hub and wrap it in the classifier."""
    backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
    return DinoVisionClassifier(backbone, num_classes=len(classes))


# model_name -> (zero-argument builder, path of the fine-tuned state dict).
_MODEL_REGISTRY = {
    "EfficientNet-B3": (lambda: _build_efficientnet("efficientnet-b3"), "src/models/eff_b3_model.pt"),
    "EfficientNet-B4": (lambda: _build_efficientnet("efficientnet-b4"), "src/models/eff_b4.pt"),
    "vgg19": (_build_vgg19, "src/models/vgg19.pt"),
    "resnet50": (_build_resnet50, "src/models/resnet50.pt"),
    "dinov2_vits14": (_build_dinov2_vits14, "src/models/dinov2_vits14_0.054_98.00.pth"),
}


def load_specific_model(model_name):
    """Build the named classifier, load its fine-tuned weights and return it.

    Args:
        model_name: one of "EfficientNet-B3", "EfficientNet-B4", "vgg19",
            "resnet50" or "dinov2_vits14".

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        ValueError: if ``model_name`` is not a supported model name.
    """
    try:
        build, checkpoint = _MODEL_REGISTRY[model_name]
    except KeyError:
        # Previously an unknown name printed "Loaded model ..." and then failed
        # with an opaque AttributeError on ``None.eval()``.
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(_MODEL_REGISTRY)}"
        ) from None
    model = build()
    # Load on the CPU first so GPU-saved checkpoints also work on CPU-only hosts.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    print(f"Loaded model {model_name}")
    return model.eval().to(device)
def inference(model, inp):
    """Classify one image and return the softmax confidence for every class.

    Args:
        model: a classifier as returned by ``load_specific_model``; its
            parameters are expected to live on ``device``.
        inp: a single image accepted by the torchvision transforms
            (presumably a PIL image from the UI — TODO confirm against caller).

    Returns:
        dict mapping each label in ``classes`` to its probability (Python float).
    """
    model.eval()
    # The DINOv2 classifier expects a 224x224 center crop; the CNNs use 256x256.
    preprocess = transform_dinov2 if isinstance(model, DinoVisionClassifier) else transform
    batch = preprocess(inp).unsqueeze(0).to(device)  # add a batch dimension of 1
    # Mixed precision only on CUDA; on the CPU autocast is disabled, which
    # matches the previous plain no_grad path. ``torch.autocast`` replaces the
    # deprecated ``torch.cuda.amp.autocast``.
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        logits = model(batch)[0]
        scores = torch.nn.functional.softmax(logits, dim=0).float().cpu().tolist()
    return {label: float(scores[idx]) for idx, label in classes.items()}