File size: 7,185 Bytes
1af1e21
 
5c48a95
1af1e21
 
 
 
89f5ea8
 
5c48a95
1af1e21
e34e966
 
 
 
1af1e21
 
5c48a95
e34e966
5c48a95
fe2a725
e34e966
db0a0ae
5c48a95
 
 
43d2d34
1af1e21
e34e966
5c48a95
 
 
 
 
1af1e21
 
 
 
5c48a95
db0a0ae
488254a
 
 
e34e966
 
 
 
 
5c48a95
 
db0a0ae
5c48a95
 
 
1af1e21
 
 
 
 
 
a5498d4
1af1e21
 
b2fe1f7
084cbbd
 
 
b2fe1f7
084cbbd
 
 
b2fe1f7
 
 
1af1e21
 
84c111e
89f5ea8
5c48a95
1af1e21
 
5c48a95
 
1af1e21
488254a
1af1e21
b2fe1f7
 
84c111e
1af1e21
 
5c48a95
b2fe1f7
1af1e21
b2fe1f7
488254a
1af1e21
a3a28f6
 
84c111e
 
a3a28f6
84c111e
 
 
 
a3a28f6
84c111e
a3a28f6
e34e966
a3a28f6
e34e966
 
 
a3a28f6
84c111e
a3a28f6
db0a0ae
a3a28f6
 
84c111e
d89f903
 
84c111e
 
 
f757d50
f97e9dc
 
 
f757d50
f97e9dc
 
 
 
f757d50
f97e9dc
 
 
f757d50
f97e9dc
f757d50
f97e9dc
 
f757d50
5a2f8c3
 
f757d50
5a2f8c3
 
 
 
 
 
f757d50
5a2f8c3
f757d50
5a2f8c3
 
f757d50
 
e34e966
fe569d8
e34bafc
 
 
 
fe569d8
84c111e
fe569d8
84c111e
 
 
 
 
a3a28f6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
import joblib
import os

# Hide GPUs from TensorFlow only, so it cannot claim CUDA devices/memory that
# PyTorch may want. Do not set CUDA_VISIBLE_DEVICES here: it is process-wide,
# and since CUDA is first queried by torch.cuda.is_available() below, it would
# also hide the GPU from PyTorch and force `device` to CPU.
# Must run before TensorFlow initializes its devices (i.e. before any TF op).
tf.config.set_visible_devices([], "GPU")

# Set PyTorch device: use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a ViT-B/16 backbone with torchvision's default pretrained weights and
# replace its head with a 2-way linear layer (index order follows class_names).
# 768 is the ViT-B/16 hidden size that feeds the head.
vit_model = models.vit_b_16(weights="DEFAULT")  # Fixed deprecated 'pretrained'
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification

# Load the fine-tuned weights if available. Without them the replacement head
# keeps its random initialisation, so ViT predictions are meaningless; warn
# instead of failing silently.
# NOTE: torch.load unpickles the checkpoint; only ship trusted weight files.
vit_model_path = "vit_bc.pth"
if os.path.exists(vit_model_path):
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
else:
    print(f"Warning: ViT weights '{vit_model_path}' not found; the classification head is untrained.")
vit_model.to(device)
vit_model.eval()  # inference mode (disables dropout)

# Preprocessing applied to every uploaded image before it reaches the ViT:
# a fixed 224x224 resize, conversion to a tensor, then per-channel
# normalisation with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable labels; position i names the model's output class i.
class_names = ["Benign", "Malignant"]

# Load the trained Neural Network model (TensorFlow/Keras).
# Best effort: if the file is missing or fails to load, nn_model stays None so
# the rest of the app (e.g. the ViT path) still starts; the NN path must
# check for None before predicting.
nn_model_path = "my_NN_BC_model.keras"
nn_model = None

if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:
        print(f"Error loading NN model: {e}")

# Optional feature scaler for the NN path; remains None when the file is absent.
# NOTE: joblib.load unpickles the file, so only ship trusted scaler files.
scaler_path = "nn_bc_scaler.pkl"
scaler = None
if os.path.exists(scaler_path):
    scaler = joblib.load(scaler_path)

# Labels for the 30 numeric NN inputs, in the order they are fed to the model:
# the ten measurements below, repeated for the Mean, SE and Worst groups.
feature_names = [
    f"{statistic} {measurement}"
    for statistic in ("Mean", "SE", "Worst")
    for measurement in (
        "Radius", "Texture", "Perimeter", "Area", "Smoothness",
        "Compactness", "Concavity", "Concave Points", "Symmetry", "Fractal Dimension",
    )
]

# Sample feature vectors for the example buttons; values follow feature_names
# order (10 Mean values, then 10 SE values, then 10 Worst values).
benign_example = [
    # Mean
    9.504, 12.44, 60.34, 273.9, 0.1024, 0.06492, 0.02956, 0.02076, 0.1815, 0.06905,
    # SE
    0.2773, 0.9768, 1.909, 15.7, 0.009606, 0.01432, 0.01985, 0.01421, 0.02027, 0.002968,
    # Worst
    10.23, 15.66, 65.13, 314.9, 0.1324, 0.1148, 0.08867, 0.06227, 0.245, 0.07773,
]

malignant_example = [
    # Mean
    11.42, 20.38, 77.58, 386.1, 0.1425, 0.2839, 0.2414, 0.1052, 0.2597, 0.09744,
    # SE
    0.4956, 1.156, 3.445, 27.23, 0.00911, 0.07458, 0.05661, 0.01867, 0.05963, 0.009208,
    # Worst
    14.91, 26.5, 98.87, 567.7, 0.2098, 0.8663, 0.6869, 0.2575, 0.6638, 0.173,
]

def classify(model_choice, image=None, *features):
    """Classify a case with the selected model.

    Args:
        model_choice: "ViT" to classify ``image``, or "Neural Network" to
            classify ``features``.
        image: PIL image for the ViT path; ignored by the NN path.
        *features: the numeric NN inputs, in ``feature_names`` order.

    Returns:
        A label from ``class_names``, or a user-facing "❌ ..." message when
        required input is missing, the NN model is unavailable, or the model
        choice is unknown.
    """
    if model_choice == "ViT":
        if image is None:
            return "❌ Please upload an image for ViT classification."
        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = vit_model(input_tensor)
        return class_names[int(torch.argmax(logits, dim=1).item())]

    if model_choice == "Neural Network":
        if len(features) != len(feature_names) or any(f is None for f in features):
            return "❌ Please enter all 30 numerical features."
        if nn_model is None:
            # Report unavailability rather than a fabricated prediction.
            return "❌ Neural Network model is not available."

        input_data = np.array(features, dtype=float).reshape(1, -1)
        if scaler is not None:
            input_data = scaler.transform(input_data)
        prediction = nn_model.predict(input_data, verbose=0)
        # NOTE(review): argmax assumes the network emits one score per class;
        # a single sigmoid output would always map to index 0 — confirm.
        return class_names[int(np.argmax(prediction))]

    return f"❌ Unknown model choice: {model_choice}"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")

    image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)

    # Feature inputs, three per row. A Gradio component is placed in whichever
    # layout context is open when it is created, so each Number is built inside
    # its `with gr.Row()` block (gr.Row is a context manager and does not take
    # child components as an argument). They start hidden because the default
    # model is ViT; toggle_inputs reveals them for the Neural Network.
    feature_inputs = []
    with gr.Column():
        for start in range(0, len(feature_names), 3):
            with gr.Row():
                for feature in feature_names[start:start + 3]:
                    feature_inputs.append(gr.Number(label=feature, visible=False))

    def fill_example(example):
        """Return a component -> value mapping that pre-fills the feature inputs with *example*."""
        return dict(zip(feature_inputs, example))

    with gr.Row():
        example_btn_1 = gr.Button("🔴 Malignant Example")
        example_btn_2 = gr.Button("🔵 Benign Example")

    output_text = gr.Textbox(label="🔍 Model Prediction", interactive=False)

    def extract_features_from_file(file):
        """Parse the uploaded file into the 30 feature inputs.

        The file must contain exactly 30 numbers separated by commas and/or
        whitespace. With ``type="binary"`` the upload is expected as ``bytes``;
        file-like objects are accepted too. Problems are raised as ``gr.Error``
        (shown to the user) because this handler feeds 30 output components,
        so a single error string cannot be returned.
        """
        if file is None:
            # The upload was cleared: leave the current feature values untouched.
            return [gr.update() for _ in feature_inputs]

        try:
            raw = file.read() if hasattr(file, "read") else file
            content = raw.decode("utf-8").strip()
            values = [float(x) for x in content.replace(",", " ").split()]
        except (UnicodeDecodeError, ValueError) as e:
            raise gr.Error(f"❌ Error processing file: {e}") from e

        if len(values) != len(feature_inputs):
            raise gr.Error("❌ The file must contain exactly 30 numerical values.")

        return dict(zip(feature_inputs, values))

    # File upload for the NN model; hidden until "Neural Network" is selected.
    file_input = gr.File(label="📂 Upload Feature File (for NN)", type="binary", visible=False)

    def toggle_inputs(choice):
        """Show the image input for ViT; show the feature fields and file upload for the NN."""
        use_nn = choice == "Neural Network"
        return (
            [gr.update(visible=choice == "ViT")]
            + [gr.update(visible=use_nn) for _ in feature_inputs]
            + [gr.update(visible=use_nn)]
        )

    model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs, file_input])

    # Process uploaded file and populate feature fields
    file_input.change(extract_features_from_file, file_input, feature_inputs)

    # Each example button fills the example its label names.
    example_btn_1.click(lambda: fill_example(malignant_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(benign_example), None, feature_inputs)

    classify_button = gr.Button("🚀 Classify")
    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)

demo.launch()