# bodypartxr / app.py
# Jason Adrian
# bodypartxr classifier
# d360108
# raw
# history blame
# 2.14 kB
import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from torchvision.transforms import transforms

from resnet18 import ResNet18
def _strip_prefix(state_dict, prefix="net."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Only a prefix at the very start of a key is removed, so inner module
    names that happen to contain the same text (e.g. ``resnet.``) are left
    untouched. Keys without the prefix are copied unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model(checkpoint_path):
    """Build the classifier, load its weights from *checkpoint_path*, set eval mode.

    The checkpoint is expected to be a dict with a ``'state_dict'`` entry whose
    keys carry a ``net.`` prefix (presumably a PyTorch Lightning checkpoint whose
    module stored the network as ``self.net`` -- confirm). Our bare ResNet18 has
    no ``net.`` level, so that prefix is stripped before loading.
    """
    # Presumably (in_channels=1, num_classes=5): matches the single-channel
    # transform and the five class names -- confirm against resnet18.ResNet18.
    net = ResNet18(1, 5)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only host. NOTE(review): recent torch releases default torch.load to
    # weights_only=True; if this checkpoint pickles extra objects, loading may
    # need weights_only=False (only for trusted files).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(_strip_prefix(checkpoint["state_dict"]))
    net.eval()
    return net


# The previous hard-coded path was a plain (non-raw) string, so "\b" and "\a"
# inside it were parsed as backspace/bell characters and the path was corrupted;
# it was also an absolute path on one developer's machine. Default to the
# checkpoint next to this script; BODYPARTXR_CHECKPOINT overrides it.
CHECKPOINT_PATH = Path(
    os.environ.get(
        "BODYPARTXR_CHECKPOINT",
        Path(__file__).resolve().parent / "acc=0.94.ckpt",
    )
)
model = _load_model(CHECKPOINT_PATH)
# Output labels in alphabetical order. Each name's position is paired with the
# model's output at the same index (presumably the training label order --
# confirm).
class_names = sorted(['abdominal', 'adult', 'others', 'pediatric', 'spine'])
# Single-channel intensity statistics used by Normalize (presumably computed
# on the training images -- confirm).
_GRAY_MEAN = [0.50807575]
_GRAY_STD = [0.20823]

# numpy image -> PIL image -> 1-channel grayscale -> 384x384 centre crop
# -> float tensor in [0, 1] -> standardised with the statistics above.
# NOTE(review): there is no Resize before the crop, so larger images lose their
# borders and smaller ones are zero-padded by CenterCrop -- confirm this matches
# the preprocessing used in training.
transformation_pipeline = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=1),
        transforms.CenterCrop((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_GRAY_MEAN, std=_GRAY_STD),
    ]
)
def preprocess_image(image: np.ndarray):
    """Turn a raw callback image into a model-ready batch of one.

    The input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        Output of ``transformation_pipeline`` with a leading batch axis of
        size 1; with the current pipeline this is (1, 1, 384, 384).
    """
    single = transformation_pipeline(image)
    return single.unsqueeze(0)
def image_classifier(inp):
    """Classify an input image into one of ``class_names``.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from callback (RGB array), or ``None`` when no image
        was supplied.

    Returns
    -------
    Dict
        Mapping of every class name to its softmax probability. Empty when
        ``inp`` is ``None``.
    """
    # Nothing to classify: return no scores rather than made-up labels
    # (the old placeholder {'cat': ..., 'dog': ...} named classes this model
    # does not have).
    if inp is None:
        return {}

    batch = preprocess_image(inp).to(dtype=torch.float32)

    # Inference only: skip autograd bookkeeping entirely.
    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class axis, then take the single item of the batch;
    # tolist() yields plain Python floats for the Gradio label.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probabilities))
# Web UI: an uploaded image is passed to image_classifier, and the returned
# {class name: probability} dict is shown with Gradio's "label" output.
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")

# Only start the server when this file is run as a script, so importing the
# module (e.g. from tests or another app) does not launch the web server.
if __name__ == "__main__":
    demo.launch()