Spaces:
Running
Running
import torch | |
import torchvision.transforms as transforms | |
import torchvision.models as models | |
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from PIL import Image | |
from sklearn.preprocessing import StandardScaler | |
import joblib | |
from transformers import ViTForImageClassification, ViTImageProcessor | |
# Set device: run on GPU when one is visible, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned ViT classifier.
# NOTE(review): torch.load resolves this as a *local* filesystem path, not a
# Hugging Face Hub reference; it only works if the file exists relative to the
# working directory (e.g. committed into the Space repo). Confirm.
model_path = "andromeda01111/ViT_BCC/model.pkl"
# The file holds a whole pickled model object (it is .eval()'d and called
# directly below/in classify, which a bare state_dict could not be), so it must
# be unpickled with weights_only=False: PyTorch >= 2.6 defaults to
# weights_only=True, which rejects arbitrary pickled objects.
# SECURITY: unpickling can execute code -- only point this at a trusted file.
vit_model = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
# classify() sends its inputs to `device`, so the model must live there too;
# leaving it on CPU breaks inference whenever CUDA is available.
vit_model = vit_model.to(device)
vit_model.eval()

# Image preprocessing for the ViT model; classify() depends on this
# module-level `vit_processor` name.
# NOTE(review): assumes the model was fine-tuned from a standard
# google/vit-base-patch16-224* checkpoint -- confirm, and load the processor
# that was actually used at training time if it differs.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Load the feature-based Keras model.
# NOTE(review): like model_path above, this is resolved as a local path, not
# fetched from the Hub -- confirm it exists in the Space.
nn_model = tf.keras.models.load_model("andromeda01111/NN_BC")

# Feature scaler applied in classify() before nn_model.predict; presumably the
# scaler fitted on the training data (ensure it's uploaded with the app).
# SECURITY: joblib files are pickles -- load only trusted files.
scaler = joblib.load("scaler.pkl")

# Class labels, indexed by predicted class id.
class_names = ["Benign", "Malignant"]
# Define feature names: one label per numeric input, ordered statistic-major
# (all "Mean" measures, then all "SE", then all "Worst"), 10 measures each.
# NOTE(review): these look like the 30 features of the Wisconsin Diagnostic
# Breast Cancer dataset -- confirm the order matches what the scaler/model
# were trained on.
feature_names = [
    f"{stat} {measure}"
    for stat in ("Mean", "SE", "Worst")
    for measure in (
        "Radius",
        "Texture",
        "Perimeter",
        "Area",
        "Smoothness",
        "Compactness",
        "Concavity",
        "Concave Points",
        "Symmetry",
        "Fractal Dimension",
    )
]
def _predict_vit(image):
    """Classify a PIL image with the ViT model and return its class label."""
    # Force 3-channel RGB so grayscale/RGBA uploads are handled uniformly.
    image = image.convert("RGB")
    inputs = vit_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vit_model(**inputs)
    predicted_class = int(torch.argmax(outputs.logits, dim=1).item())
    return class_names[predicted_class]


def _predict_nn(features):
    """Scale the feature values and classify them with the Keras model."""
    # One sample per request -> a (1, n_features) row for the scaler/model.
    input_data = np.asarray(features, dtype=float).reshape(1, -1)
    input_data_std = scaler.transform(input_data)
    scores = np.asarray(nn_model.predict(input_data_std)).reshape(-1)
    if scores.size == 1:
        # A single output unit cannot be argmax-ed (argmax of one value is
        # always 0, i.e. always class_names[0]); treat it as a sigmoid
        # probability of class_names[1].
        # NOTE(review): confirm the training label encoding matches
        # class_names (sklearn's load_breast_cancer uses 0 = malignant).
        predicted_class = int(scores[0] >= 0.5)
    else:
        predicted_class = int(np.argmax(scores))
    return class_names[predicted_class]


def classify(model_choice, image=None, *features):
    """Route a Gradio request to the chosen model and return a label.

    Args:
        model_choice: "ViT" or "Neural Network" (the radio value; None when
            nothing has been selected).
        image: PIL image used by the ViT branch; ignored otherwise.
        *features: the 30 numeric values, in feature_names order, used by the
            Neural Network branch; ignored otherwise.

    Returns:
        A label from class_names, or a user-facing prompt string when the
        input required by the chosen model is missing.
    """
    if model_choice == "ViT":
        if image is None:
            return "Please upload an image."
        return _predict_vit(image)

    if model_choice == "Neural Network":
        # One value per gr.Number field; presumably an empty field arrives as
        # None. A wrong count would otherwise crash inside scaler.transform.
        if len(features) != 30 or any(f is None for f in features):
            return "Please enter all 30 numerical features."
        return _predict_nn(features)

    # No (or an unknown) model selected: the radio has no default value, so
    # without this the interface would just receive None.
    return "Please choose a model."
# Gradio UI
# Gradio passes component values to `classify` positionally, so the order of
# `inputs` below must line up with its signature:
# (model_choice, image, *features) -- one gr.Number per entry of feature_names.
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
image_input = gr.Image(type="pil", label="Upload Mammogram Image")
feature_inputs = [gr.Number(label=feature_label) for feature_label in feature_names]

iface = gr.Interface(
    fn=classify,
    inputs=[model_selector, image_input, *feature_inputs],
    outputs="text",
    title="Breast Cancer Classification",
    description=(
        "Choose ViT (image-based) or Neural Network (feature-based) classification."
    ),
)
iface.launch()