EfficientNet Parkinson's Prediction Model 🤗

This repository contains a Hugging Face EfficientNet model that predicts Parkinson's disease from patient drawings, with an accuracy of around 83%. Built with EfficientNet and PyTorch.

Overview

Parkinson's disease is a progressive nervous system disorder that affects movement. Symptoms begin gradually, sometimes with a barely noticeable tremor in just one hand. Tremors are common, but the disorder also often causes stiffness or slowing of movement.

My model uses the EfficientNet architecture to predict the likelihood of Parkinson's disease by analysing patients' drawings. Contributions are welcome: feel free to open a pull request.

Dataset

The dataset used to train this model was obtained from Kaggle.

Usage

import torch
from transformers import AutoModel
from torch import nn
from PIL import Image
import numpy as np

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained model
# ('/content/final' is a local directory of saved weights; point it at your own copy)
model = AutoModel.from_pretrained('/content/final')

# Move the model to the device
model = model.to(device)

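# Load and resize the new image
image_size = (224, 224)
new_image = Image.open('/content/health.png').convert('RGB').resize(image_size)
new_image = np.array(new_image)  # shape (H, W, 3)
# transpose(0, 2) gives (3, W, H), which also swaps height and width;
# use .permute(2, 0, 1) for the standard (3, H, W) layout if that is
# how the model was trained
new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)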

# Move the data to the device
new_image = new_image.to(device)

# Make predictions using the trained model
with torch.no_grad():
    outputs = model(new_image)
    # Flatten the backbone's feature map into one vector per image
    features = outputs.last_hidden_state
    features = features.view(features.shape[0], -1)

    # Map the features to the two classes.
    # WARNING: this nn.Linear is created here with random weights, so its
    # output is not a trained prediction; load the trained classifier
    # weights into it (if available) before relying on the result.
    num_classes = 2
    feature_reducer = nn.Linear(features.shape[1], num_classes).to(device)

    logits = feature_reducer(features)
    predicted_class = torch.argmax(logits, dim=1).item()
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

    # Class 0 is Parkinson's, class 1 is Healthy
    if predicted_class == 0:
        print(f"Predicted class: Parkinson's with confidence {confidence:.2f}")
    else:
        print(f"Predicted class: Healthy with confidence {confidence:.2f}")
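The preprocessing and decision steps of the usage code can be factored into two small helpers, sketched here under assumptions: `preprocess`, `interpret`, and `LABELS` are hypothetical names; `preprocess` uses `permute(2, 0, 1)` for the standard `(channels, height, width)` layout, which differs from the spatially transposed result of `transpose(0, 2)`, so use whichever layout the model was actually trained on; and no pixel normalization is applied, since the training-time preprocessing is not documented. The class mapping (0 = Parkinson's, 1 = Healthy) comes from the usage code.

```python
from typing import Tuple

import numpy as np
import torch
from PIL import Image

# Class mapping taken from the usage code: 0 -> Parkinson's, 1 -> Healthy
LABELS = {0: "Parkinson's", 1: "Healthy"}


def preprocess(image: Image.Image, size=(224, 224)) -> torch.Tensor:
    """Convert a PIL image to a float tensor of shape (1, 3, H, W).

    Assumption: no normalization; add one if the training pipeline used it.
    """
    array = np.array(image.convert("RGB").resize(size))  # (H, W, 3), uint8
    tensor = torch.from_numpy(array).permute(2, 0, 1)    # (3, H, W)
    return tensor.float().unsqueeze(0)                   # (1, 3, H, W)


def interpret(logits: torch.Tensor) -> Tuple[str, float]:
    """Return (label, confidence) for one image's logits of shape (1, 2)."""
    probs = torch.softmax(logits, dim=1)[0]
    predicted_class = int(torch.argmax(probs).item())
    return LABELS[predicted_class], probs[predicted_class].item()


if __name__ == "__main__":
    blank = Image.new("RGB", (300, 200), color=(255, 255, 255))
    print(preprocess(blank).shape)  # torch.Size([1, 3, 224, 224])
    label, confidence = interpret(torch.tensor([[2.0, 0.0]]))
    print(f"Predicted class: {label} with confidence {confidence:.2f}")
```

Keeping these helpers separate from the model makes it easy to check that inference uses exactly the same image layout and label mapping as training.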
