Vit_BCC_APP / app.py
andromeda01111's picture
Update app.py
e34bafc verified
raw
history blame
5.77 kB
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
import joblib
import os
# Keep TensorFlow off the GPU to avoid CUDA conflicts with PyTorch.
# Setting CUDA_VISIBLE_DEVICES=-1 here would also hide the GPU from PyTorch
# (its CUDA runtime is initialised lazily, i.e. after this point), making the
# "cuda" branch below unreachable -- so restrict TensorFlow only.
try:
    tf.config.set_visible_devices([], "GPU")
except RuntimeError as e:
    # TensorFlow refuses this once its devices have already been initialised.
    print(f"Could not hide GPUs from TensorFlow: {e}")

# PyTorch runs on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViT-B/16 from torchvision's default pretrained weights, with the classification
# head replaced by a fresh 2-way linear layer (binary classification).
vit_model = models.vit_b_16(weights="DEFAULT")
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)

# Fine-tuned weights (full state_dict), if shipped alongside the app.
vit_model_path = "vit_bc.pth"
if os.path.exists(vit_model_path):
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
else:
    # Without these weights the head above stays randomly initialised, so any
    # ViT prediction is meaningless -- make that visible in the logs.
    print(f"Warning: {vit_model_path} not found; ViT predictions will be unreliable.")
vit_model.to(device)
vit_model.eval()
# Preprocessing pipeline for ViT inputs: resize to a fixed 224x224, convert to
# a tensor, then normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Output labels shared by both models; entry i is the label for class id i.
class_names = [
    "Benign",     # class id 0
    "Malignant",  # class id 1
]
# Load the trained Neural Network (TensorFlow/Keras) and its feature scaler.
# Both loads are best-effort: on a missing or unreadable file the value stays
# None and the reason is printed, so the app still starts.
nn_model_path = "my_NN_BC_model.keras"
nn_model = None
if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:
        print(f"Error loading NN model: {e}")
else:
    print(f"NN model file not found: {nn_model_path}")

# Scaler used to normalise the 30 input features before prediction.
# NOTE: joblib.load unpickles the file -- only ship/load trusted artifacts.
scaler_path = "nn_bc_scaler.pkl"
scaler = None
if os.path.exists(scaler_path):
    try:
        scaler = joblib.load(scaler_path)
    except Exception as e:
        print(f"Error loading scaler: {e}")
else:
    print(f"Scaler file not found: {scaler_path}; features will not be standardized.")
# Feature names: ten measurements, each reported as a "Mean", an "SE" and a
# "Worst" value -- 30 names in Mean/SE/Worst order.
# NOTE(review): these appear to match the Wisconsin Diagnostic Breast Cancer
# feature set; confirm against the data the NN was trained on.
feature_names = [
    f"{statistic} {measurement}"
    for statistic in ("Mean", "SE", "Worst")
    for measurement in (
        "Radius", "Texture", "Perimeter", "Area", "Smoothness",
        "Compactness", "Concavity", "Concave Points", "Symmetry", "Fractal Dimension",
    )
]
# Example inputs: one value per feature name, grouped as the 10 "Mean" values,
# then the 10 "SE" values, then the 10 "Worst" values.
benign_example = (
    [13.54, 14.36, 87.46, 566.3, 0.09779, 0.08129, 0.06664, 0.04781, 0.1885, 0.05766]  # Mean
    + [0.2699, 0.7886, 2.058, 23.56, 0.008462, 0.0146, 0.02387, 0.01315, 0.0198, 0.0023]  # SE
    + [15.11, 19.26, 99.7, 711.2, 0.144, 0.1773, 0.239, 0.1288, 0.2977, 0.07259]  # Worst
)
malignant_example = (
    [17.99, 10.38, 122.8, 1001.0, 0.1184, 0.2776, 0.3001, 0.1471, 0.2419, 0.07871]  # Mean
    + [1.095, 0.9053, 8.589, 153.4, 0.006399, 0.04904, 0.05373, 0.01587, 0.03003, 0.006193]  # SE
    + [25.38, 17.33, 184.6, 2019.0, 0.1622, 0.6656, 0.7119, 0.2654, 0.4601, 0.1189]  # Worst
)
def classify(model_choice, image=None, *features):
    """Classify using ViT (image) or NN (features).

    Args:
        model_choice: "ViT" to classify ``image``, or "Neural Network" to
            classify the numeric ``features``.
        image: PIL image from the image input (ViT only); None when empty.
        *features: the numeric feature values, in feature-name order
            (Neural Network only); empty inputs arrive as None.

    Returns:
        A user-facing message: the predicted class label, or an error message
        starting with "❌" when input, the model, or the model choice is invalid.
    """
    if model_choice == "ViT":
        if image is None:
            return "❌ Please upload an image for ViT classification."
        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = vit_model(input_tensor)
        predicted_class = torch.argmax(logits, dim=1).item()
        return f"🔍 **Prediction:** {class_names[predicted_class]}"

    if model_choice == "Neural Network":
        if any(f is None for f in features):
            return "❌ Please enter all 30 numerical features."
        if nn_model is None:
            # Refuse rather than fabricate a score: a placeholder output would
            # report the same diagnosis for every case.
            return "❌ Neural Network model is not available."
        input_data = np.array(features, dtype=float).reshape(1, -1)
        input_data_std = scaler.transform(input_data) if scaler else input_data
        scores = np.asarray(nn_model.predict(input_data_std, verbose=0)).reshape(-1)
        if scores.size == 1:
            # Single output unit: argmax would always pick class 0, so threshold.
            # NOTE(review): assumes a sigmoid giving P(class 1) -- confirm.
            predicted_class = int(scores[0] > 0.5)
        else:
            predicted_class = int(np.argmax(scores))
        return f"🔍 **Prediction:** {class_names[predicted_class]}"

    return f"❌ Unknown model choice: {model_choice}"
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")
        image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)

    # Numeric inputs for the Neural Network, laid out three per row. Gradio
    # places a component in the layout context active when it is created, so
    # each Number is built inside its Row. The panel starts hidden because the
    # default model is "ViT".
    feature_inputs = []
    with gr.Column(visible=False) as feature_panel:
        for start in range(0, len(feature_names), 3):
            with gr.Row():
                for name in feature_names[start:start + 3]:
                    feature_inputs.append(gr.Number(label=name))

        # Example buttons; they only fill the feature inputs, so they live in
        # the same panel.
        with gr.Row():
            example_btn_1 = gr.Button("🔵 Benign Example")
            example_btn_2 = gr.Button("🔴 Malignant Example")

    output_text = gr.Textbox(label="🔍 Model Prediction", interactive=False)

    def fill_example(example):
        """Pre-fills example inputs (one value per feature input, in order)."""
        return list(example)

    def toggle_inputs(choice):
        """Toggle visibility of inputs based on model selection."""
        return gr.update(visible=choice == "ViT"), gr.update(visible=choice == "Neural Network")

    model_selector.change(toggle_inputs, model_selector, [image_input, feature_panel])
    example_btn_1.click(lambda: fill_example(benign_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(malignant_example), None, feature_inputs)

    classify_button = gr.Button("🚀 Classify")
    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)

demo.launch()