gitfreder's picture
Update app.py
297bef2 verified
import functools

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision import transforms as T
from transformers import pipeline
class DownSyndromeDetection(nn.Module, PyTorchModelHubMixin):
    """Small CNN that scores an RGB image against two classes.

    Five 3x3 convolutions (tanh activation), each followed by a 2x2 max-pool,
    feed a two-layer fully connected head. ``forward`` returns
    log-probabilities over the two classes (log-softmax along dim 1).

    The layer attribute names must not change: they are the state-dict keys
    that ``from_pretrained`` uses to load the published weights.
    """

    def __init__(self):
        super().__init__()
        # Channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256. A 3x3 kernel with
        # stride 1 and padding 1 leaves the spatial size unchanged.
        self.conv_layer_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv_layer_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv_layer_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv_layer_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv_layer_5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        # One pooling module, reused after every conv; halves height/width.
        self.pooling_layer = nn.MaxPool2d(kernel_size=2)
        # For the 255x255 input the inference code resizes to, five halvings
        # give 255 -> 127 -> 63 -> 31 -> 15 -> 7, leaving a 256 x 7 x 7 map.
        # That is 12544 features (same value as the former 16 * 28 * 28).
        self.fc_layer_1 = nn.Linear(256 * 7 * 7, 512)
        self.fc_layer_2 = nn.Linear(512, 2)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        conv_stack = (
            self.conv_layer_1,
            self.conv_layer_2,
            self.conv_layer_3,
            self.conv_layer_4,
            self.conv_layer_5,
        )
        for conv in conv_stack:
            x = self.pooling_layer(torch.tanh(conv(x)))

        # Flatten everything except the batch dimension.
        batch_size = x.size(0)
        features = x.view(batch_size, -1)

        hidden = torch.relu(self.fc_layer_1(features))
        logits = self.fc_layer_2(hidden)
        return torch.log_softmax(logits, dim=1)
def calc_result_confidence(model_output):
    """Pick the most likely class for a single-row batch of model scores.

    ``model_output`` is normalized with softmax along dim 1. (Softmax is
    shift-invariant, so log-softmax scores and raw logits give the same
    probabilities.)

    Returns:
        ``(confidence, class_index)`` as Python scalars: the highest
        probability and the index where it occurs. ``.item()`` requires the
        batch to contain exactly one row.
    """
    probabilities = model_output.softmax(dim=1)
    top = probabilities.max(dim=1)
    return top.values.item(), top.indices.item()
# Human-readable labels, indexed by the class id the model predicts.
_CLASS_NAMES = ('Down Syndrome', 'Healthy')

# Preprocessing applied to every uploaded image. The 255x255 size is tied to
# DownSyndromeDetection.fc_layer_1: five 2x2 max-pools shrink 255 down to 7,
# and 256 * 7 * 7 = 12544 is the input width that layer was built with.
_INFER_TRANSFORM = T.Compose([
    T.Resize((255, 255)),
    T.ToTensor(),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Fetch the pretrained model once per process and return it in eval mode."""
    # from_pretrained is a classmethod; call it on the class instead of
    # building (and discarding) a randomly initialized instance first.
    model = DownSyndromeDetection.from_pretrained('gitfreder/down-syndrome-detection')
    model.eval()
    return model


def downsyndrome_gradio_inference(img_file):
    """Classify one uploaded image with the pretrained model.

    Args:
        img_file: PIL image provided by the ``gr.Image(type='pil')`` input,
            or None when the user submits without uploading an image.

    Returns:
        dict with the predicted label under ``'Predicted'`` and that label's
        softmax probability (a float) under ``'Confidence Score'``.

    Raises:
        gr.Error: if no image was provided.
    """
    if img_file is None:
        raise gr.Error('Please upload an image.')

    # Add a leading batch dimension of 1.
    batch = _INFER_TRANSFORM(img_file.convert('RGB')).float().unsqueeze(0)
    model = _load_model()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        conf, cls = calc_result_confidence(model(batch))
    return {
        'Predicted': _CLASS_NAMES[cls],
        'Confidence Score': conf,
    }
# Web UI: a PIL image in, a JSON dict with the prediction out.
iface = gr.Interface(
    fn=downsyndrome_gradio_inference,
    inputs=gr.Image(type='pil'),
    outputs=gr.JSON(),
    title="Down Syndrome Detection",
    description="A model interface that detects Down syndrome in children from a photo.",
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link (see launch() docs).
    iface.launch(share=True)