# Header scraped from the Hugging Face file view (page UI residue, not code):
# author: dhhd255 · commit f143c99 "Update app.py" · file size 2.16 kB
import torch
from transformers import AutoModel
import torch.nn as nn
from PIL import Image
import numpy as np
import streamlit as st
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model(model_id, device_name):
    """Load a pretrained model from the Hugging Face Hub onto ``device_name``.

    Streamlit re-executes this whole script on every interaction (e.g. each
    new upload). ``st.cache_resource`` keeps one loaded model per
    ``(model_id, device_name)`` across reruns and sessions instead of
    rebuilding it each time. The cached object is shared, so callers must not
    mutate it. Both arguments are plain strings so the cache key is trivially
    hashable.

    Args:
        model_id: Hub repository id passed to ``AutoModel.from_pretrained``.
        device_name: Torch device string such as ``'cpu'`` or ``'cuda'``.

    Returns:
        The loaded model, moved to ``device_name``.
    """
    return AutoModel.from_pretrained(model_id).to(device_name)


# Load the trained model from the Hugging Face Hub (cached across reruns).
model = _load_model('dhhd255/parkinsons_pred0.1', str(device))
# Page-wide stylesheet: switch the body font to Inter (loaded from Google
# Fonts) and define the `.healthy` (green) and `.parkinsons` (red) classes
# that colour the prediction text rendered below.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
body {
font-family: 'Inter', sans-serif;
}
.healthy {
color: #007E3F;
}
.parkinsons {
color: #C30000;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page heading, followed by an upload widget restricted to PNG/JPEG files.
# `uploaded_file` is None until the user has chosen a file.
PAGE_TITLE = "Parkinson's Disease Prediction"
ACCEPTED_IMAGE_TYPES = ["png", "jpg", "jpeg"]

st.title(PAGE_TITLE)
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=ACCEPTED_IMAGE_TYPES,
)
if uploaded_file is not None:
    # Decode the upload as RGB and resize it to the fixed input size used below.
    # A file with an image extension but undecodable content makes PIL raise
    # (PIL.UnidentifiedImageError subclasses OSError); report it to the user
    # instead of crashing the app with a traceback.
    image_size = (224, 224)
    try:
        pil_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
    except OSError:
        st.error("Could not read the uploaded file as an image.")
        st.stop()  # Streamlit runs no statements after st.stop()

    # Left column: the uploaded image; right column: the prediction.
    col1, col2 = st.columns(2)
    col1.image(pil_image, use_column_width=True)

    # (H, W, C) uint8 array -> float tensor with a leading batch dimension.
    # NOTE(review): transpose(0, 2) maps (H, W, C) to (C, W, H), so the spatial
    # axes end up swapped relative to the usual (C, H, W) layout that
    # permute(2, 0, 1) would give. Kept as-is because the model's training
    # preprocessing is not visible here; confirm it used the same transform.
    # Pixel values are also left in 0-255 with no normalisation; verify that
    # this matches training as well.
    new_image = np.array(pil_image)
    new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
    new_image = new_image.to(device)

    num_classes = 2
    # Inference only: keep every forward pass (backbone and head) out of autograd.
    with torch.no_grad():
        predictions = model(new_image)
        # Flatten the backbone's last hidden state into one feature vector per
        # batch item; reshape (unlike view) also copes with a non-contiguous
        # tensor.
        logits = predictions.last_hidden_state
        logits = logits.reshape(logits.shape[0], -1)
        # NOTE(review): this linear head is constructed fresh, with random
        # initial weights, on every run, and no trained weights are loaded into
        # it anywhere in this file. The predicted class therefore does not
        # reflect anything learned and may change between reruns for the same
        # image. Load the trained classification head (or a model class that
        # includes one) before relying on these predictions.
        feature_reducer = nn.Linear(logits.shape[1], num_classes).to(device)
        logits = feature_reducer(logits)
    predicted_class = torch.argmax(logits, dim=1).item()

    # Assumes class index 0 = Parkinson's and 1 = healthy (training label
    # order); confirm against the training code.
    if predicted_class == 0:
        col2.markdown('<span class="parkinsons">Predicted class: Parkinson\'s</span>', unsafe_allow_html=True)
    else:
        col2.markdown('<span class="healthy">Predicted class: Healthy</span>', unsafe_allow_html=True)