import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
import joblib
import os
# Hide GPUs from this process to avoid TensorFlow/CUDA conflicts
# (note: this also makes PyTorch fall back to CPU below)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Set PyTorch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load trained ViT model (PyTorch)
vit_model = models.vit_b_16(weights="DEFAULT")  # Fixed deprecated 'pretrained'
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification

# Load ViT model weights (if available)
vit_model_path = "vit_bc.pth"
if os.path.exists(vit_model_path):
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
vit_model.to(device)
vit_model.eval()
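# If "vit_bc.pth" is missing, the backbone keeps its ImageNet-pretrained weights and
# the 2-class head above stays randomly initialized, so ViT predictions are only
# meaningful when the fine-tuned checkpoint is present.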
# Define image transformations for ViT
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
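# These are the standard ImageNet preprocessing values (224x224 input, ImageNet
# mean/std) expected by torchvision's pretrained ViT weights.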
# Class labels
class_names = ["Benign", "Malignant"]
# Load trained Neural Network model (TensorFlow/Keras)
nn_model_path = "my_NN_BC_model.keras"
nn_model = None
if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:
        print(f"Error loading NN model: {e}")

# Load scaler for feature normalization
scaler_path = "nn_bc_scaler.pkl"
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
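# The scaler is expected to be the StandardScaler fitted on the NN training features;
# if the pickle is missing, raw feature values are passed to the network unscaled
# (see classify below).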
# Feature names
feature_names = [
    "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
    "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
    "SE Radius", "SE Texture", "SE Perimeter", "SE Area", "SE Smoothness",
    "SE Compactness", "SE Concavity", "SE Concave Points", "SE Symmetry", "SE Fractal Dimension",
    "Worst Radius", "Worst Texture", "Worst Perimeter", "Worst Area", "Worst Smoothness",
    "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
]
# Example inputs
benign_example = [13.54, 14.36, 87.46, 566.3, 0.09779, 0.08129, 0.06664, 0.04781, 0.1885, 0.05766,
                  0.2699, 0.7886, 2.058, 23.56, 0.008462, 0.0146, 0.02387, 0.01315, 0.0198, 0.0023,
                  15.11, 19.26, 99.7, 711.2, 0.144, 0.1773, 0.239, 0.1288, 0.2977, 0.07259]
malignant_example = [17.99, 10.38, 122.8, 1001.0, 0.1184, 0.2776, 0.3001, 0.1471, 0.2419, 0.07871,
                     1.095, 0.9053, 8.589, 153.4, 0.006399, 0.04904, 0.05373, 0.01587, 0.03003, 0.006193,
                     25.38, 17.33, 184.6, 2019.0, 0.1622, 0.6656, 0.7119, 0.2654, 0.4601, 0.1189]
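# These example vectors appear to follow the 30-feature layout of the Wisconsin
# Diagnostic Breast Cancer dataset (10 measurements, each as mean / SE / worst),
# matching the order of feature_names above.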
def classify(model_choice, image=None, *features):
    """Classify using ViT (image) or NN (features)."""
    if model_choice == "ViT":
        if image is None:
            return "Please upload an image for ViT classification."
        image = image.convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = vit_model(input_tensor)
        predicted_class = torch.argmax(output, dim=1).item()
        return f"**Prediction:** {class_names[predicted_class]}"
    elif model_choice == "Neural Network":
        if any(f is None for f in features):
            return "Please enter all 30 numerical features."
        input_data = np.array(features).reshape(1, -1)
        input_data_std = scaler.transform(input_data) if scaler else input_data
        prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
        predicted_class = np.argmax(prediction)
        return f"**Prediction:** {class_names[predicted_class]}"
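# Optional sanity check outside the UI (assumes the NN checkpoint and scaler files exist);
# uncomment to exercise the feature-based path end to end:
# print(classify("Neural Network", None, *benign_example))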
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")
        image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)
    # Feature inputs, organized into rows of 3 columns
    feature_inputs = []
    with gr.Row():
        with gr.Column():
            for i in range(0, len(feature_names), 3):
                with gr.Row():
                    for name in feature_names[i:i + 3]:
                        feature_inputs.append(gr.Number(label=name))
    # Example buttons
    def fill_example(example):
        """Pre-fills example inputs."""
        return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}

    with gr.Row():
        example_btn_1 = gr.Button("🔵 Benign Example")
        example_btn_2 = gr.Button("🔴 Malignant Example")

    output_text = gr.Textbox(label="Model Prediction", interactive=False)
    # Toggle input fields based on model selection
    def toggle_inputs(choice):
        """Toggle visibility of inputs based on model selection."""
        image_visibility = choice == "ViT"
        feature_visibility = choice == "Neural Network"
        return [gr.update(visible=image_visibility)] + [gr.update(visible=feature_visibility)] * len(feature_inputs)

    model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs])
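    # The list returned by toggle_inputs must match the order of the outputs above:
    # one update for image_input followed by one update per feature input.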
    example_btn_1.click(lambda: fill_example(benign_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(malignant_example), None, feature_inputs)

    classify_button = gr.Button("Classify")
    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)

demo.launch()
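# On Hugging Face Spaces the default launch() is sufficient; when running locally,
# Gradio also supports demo.launch(share=True) to create a temporary public link.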