Spaces:
Sleeping
Sleeping
import torch | |
import torchvision.transforms as transforms | |
import torchvision.models as models | |
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from PIL import Image | |
from sklearn.preprocessing import StandardScaler | |
import joblib | |
import os | |
# Keep TensorFlow (only used for the tabular NN) off the GPU to avoid CUDA
# conflicts with PyTorch in the same process. Setting CUDA_VISIBLE_DEVICES="-1"
# would also hide every GPU from PyTorch, making the "cuda" branch below dead.
try:
    tf.config.set_visible_devices([], "GPU")
except RuntimeError as err:
    # Raised if TensorFlow already initialised its devices; report it instead
    # of silently leaving TF on the GPU.
    print(f"Could not hide GPUs from TensorFlow: {err}")

# PyTorch (ViT) runs on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load trained ViT model (PyTorch): ViT-B/16 backbone with a 2-class head.
vit_model_path = "vit_bc.pth"
_vit_checkpoint_found = os.path.exists(vit_model_path)

# load_state_dict() below matches keys strictly, so the checkpoint overwrites
# every parameter; only download the ImageNet-pretrained weights when there is
# no checkpoint to load.
vit_model = models.vit_b_16(weights=None if _vit_checkpoint_found else "DEFAULT")
# Replace the classification head with one linear layer: 768 -> 2 (Benign, Malignant).
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)

if _vit_checkpoint_found:
    # weights_only=True restricts unpickling to tensors/plain containers, which
    # is all a state_dict needs.
    vit_model.load_state_dict(
        torch.load(vit_model_path, map_location=device, weights_only=True)
    )
else:
    print(
        f"Warning: ViT weights '{vit_model_path}' not found; the 2-class head is "
        "randomly initialised, so ViT predictions are not meaningful."
    )

vit_model.to(device)
vit_model.eval()
# Output labels, indexed by the position returned by argmax over either
# model's output (0 -> "Benign", 1 -> "Malignant").
class_names = ["Benign", "Malignant"]

# ViT preprocessing: resize to the 224x224 input size, convert to a tensor,
# then normalise with the standard ImageNet per-channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load trained Neural Network model (TensorFlow/Keras).
# nn_model stays None when the file is missing or fails to load, so the UI can
# still start and classify() can report the NN as unavailable.
nn_model_path = "my_NN_BC_model.keras"
nn_model = None
if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:  # best-effort: keep the ViT path usable
        print(f"Error loading NN model: {e}")
else:
    print(
        f"Warning: NN model file '{nn_model_path}' not found; "
        "Neural Network classification is disabled."
    )
# Load the feature scaler applied to NN inputs before prediction (None if absent).
scaler_path = "nn_bc_scaler.pkl"
scaler = None
if os.path.exists(scaler_path):
    # NOTE: joblib.load unpickles arbitrary objects; only load this bundled,
    # trusted file, never a user-supplied path.
    scaler = joblib.load(scaler_path)
else:
    print(
        f"Warning: scaler file '{scaler_path}' not found; "
        "NN inputs will be passed to the model unscaled."
    )
# Display names of the 30 tabular NN inputs: ten base measurements, each given
# as a Mean, an SE and a Worst value, in that order. The UI builds the feature
# fields and maps the example lists in this same order.
_MEASUREMENTS = (
    "Radius", "Texture", "Perimeter", "Area", "Smoothness",
    "Compactness", "Concavity", "Concave Points", "Symmetry", "Fractal Dimension",
)
feature_names = [
    f"{statistic} {measurement}"
    for statistic in ("Mean", "SE", "Worst")
    for measurement in _MEASUREMENTS
]
# Example inputs: 30 values each, in feature_names order
# (10 Mean values, then 10 SE values, then 10 Worst values).
benign_example = [
    # Mean
    9.504, 12.44, 60.34, 273.9, 0.1024, 0.06492, 0.02956, 0.02076, 0.1815, 0.06905,
    # SE
    0.2773, 0.9768, 1.909, 15.7, 0.009606, 0.01432, 0.01985, 0.01421, 0.02027, 0.002968,
    # Worst
    10.23, 15.66, 65.13, 314.9, 0.1324, 0.1148, 0.08867, 0.06227, 0.245, 0.07773,
]
malignant_example = [
    # Mean
    11.42, 20.38, 77.58, 386.1, 0.1425, 0.2839, 0.2414, 0.1052, 0.2597, 0.09744,
    # SE
    0.4956, 1.156, 3.445, 27.23, 0.00911, 0.07458, 0.05661, 0.01867, 0.05963, 0.009208,
    # Worst
    14.91, 26.5, 98.87, 567.7, 0.2098, 0.8663, 0.6869, 0.2575, 0.6638, 0.173,
]
def classify(model_choice, image=None, *features):
    """Classify a case as Benign or Malignant with the selected model.

    Args:
        model_choice: "ViT" to classify ``image``, or "Neural Network" to
            classify the tabular ``features``.
        image: PIL image; used only by the ViT path.
        *features: the 30 numeric feature values in ``feature_names`` order;
            used only by the Neural Network path.

    Returns:
        A label from ``class_names``, or a human-readable message explaining
        why no prediction could be made.
    """
    if model_choice == "ViT":
        if image is None:
            return "β Please upload an image for ViT classification."
        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = vit_model(input_tensor)
        return class_names[int(torch.argmax(logits, dim=1).item())]

    if model_choice == "Neural Network":
        if len(features) != 30 or any(f is None for f in features):
            return "β Please enter all 30 numerical features."
        if nn_model is None:
            # Never fabricate a prediction: a hard-coded fallback output would
            # report every case as the same class.
            return "Neural Network model is not loaded; cannot classify."
        input_data = np.array(features, dtype=float).reshape(1, -1)
        if scaler is not None:
            input_data = scaler.transform(input_data)
        prediction = nn_model.predict(input_data, verbose=0)
        return class_names[int(np.argmax(prediction))]

    return f"Unknown model choice: {model_choice!r}"
# Gradio UI
with gr.Blocks() as demo:
    # NOTE(review): the emoji prefixes in the UI strings below look mis-encoded
    # (UTF-8 bytes rendered through cp1253, e.g. "π©Ί" for U+1FA7A); restore the
    # intended characters in the source.
    gr.Markdown("## π©Ί Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="π¬ Choose Model", value="ViT")
        image_input = gr.Image(type="pil", label="π· Upload Image (for ViT)", visible=True)

    # All Neural Network inputs live in one column so a single update shows or
    # hides them; it starts hidden because "ViT" is the default model.
    with gr.Column(visible=False) as nn_inputs:
        file_input = gr.File(label="π Upload Feature File (for NN)", type="binary")

        # Components are placed by the layout context they are created in
        # (gr.Row does not take a list of existing components), so build the
        # 30 feature fields directly inside rows of 3.
        feature_inputs = []
        for start in range(0, len(feature_names), 3):
            with gr.Row():
                for name in feature_names[start:start + 3]:
                    feature_inputs.append(gr.Number(label=name))

        with gr.Row():
            example_btn_1 = gr.Button("π΄ Malignant Example")
            example_btn_2 = gr.Button("π΅ Benign Example")

    output_text = gr.Textbox(label="π Model Prediction", interactive=False)
    classify_button = gr.Button("π Classify")

    def fill_example(example):
        """Map a 30-value example list onto the feature fields, in order."""
        return dict(zip(feature_inputs, example))

    def extract_features_from_file(file):
        """Parse an uploaded text file of 30 numbers into the feature fields.

        Values may be separated by whitespace and/or commas. With
        ``type="binary"`` Gradio delivers the upload as ``bytes``; a file-like
        object is also accepted. Invalid files raise ``gr.Error``, which Gradio
        shows to the user, instead of returning one string to 30 outputs.
        """
        if file is None:
            # Fired when the upload is cleared: leave the fields unchanged.
            return [gr.update() for _ in feature_inputs]
        raw = file if isinstance(file, bytes) else file.read()
        try:
            # utf-8-sig also strips a BOM, which float() would otherwise reject.
            text = raw.decode("utf-8-sig")
            values = [float(tok) for tok in text.replace(",", " ").split()]
        except (UnicodeDecodeError, ValueError) as err:
            raise gr.Error(f"β Error processing file: {err}") from err
        if len(values) != len(feature_inputs):
            raise gr.Error("β The file must contain exactly 30 numerical values.")
        return values

    def toggle_inputs(choice):
        """Show the image input for ViT and the NN inputs for the Neural Network."""
        return (
            gr.update(visible=choice == "ViT"),
            gr.update(visible=choice == "Neural Network"),
        )

    # One handler only (previously two competing handlers were registered).
    model_selector.change(toggle_inputs, model_selector, [image_input, nn_inputs])
    file_input.change(extract_features_from_file, file_input, feature_inputs)
    # Each button fills the example that matches its label.
    example_btn_1.click(lambda: fill_example(malignant_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(benign_example), None, feature_inputs)
    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)

demo.launch()