andromeda01111 commited on
Commit
89f5ea8
·
verified ·
1 Parent(s): 9d893e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -45
app.py CHANGED
@@ -5,41 +5,27 @@ import gradio as gr
5
  import numpy as np
6
  import tensorflow as tf
7
  from PIL import Image
8
- from sklearn.preprocessing import StandardScaler # Required for feature scaling
9
- import joblib # To load the scaler
 
10
 
11
- # Set device for ViT model (PyTorch)
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # Load trained ViT model (PyTorch)
15
- vit_model = models.vit_b_16(pretrained=False)
16
- vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
17
 
18
- # Load ViT model weights
19
- vit_model_path = "andromeda01111/ViT_BCC"
20
- vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
21
- vit_model.to(device)
22
- vit_model.eval()
23
 
24
- # Define ViT image transformations
25
- transform = transforms.Compose([
26
- transforms.Resize((224, 224)),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
- ])
30
 
31
  # Class labels
32
  class_names = ["Benign", "Malignant"]
33
 
34
- # Load trained Neural Network model (TensorFlow/Keras)
35
- nn_model_path = "andromeda01111/NN_BC/my_NN_BC_model.keras" # Ensure the correct path
36
- nn_model = tf.keras.models.load_model(nn_model_path)
37
-
38
- # Load scaler for feature normalization
39
- scaler_path = "andromeda01111/NN_BC/nn_bc_scaler.pkl" # Update path
40
- scaler = joblib.load(scaler_path) # Load pre-fitted scaler
41
-
42
- # Define feature names for NN model
43
  feature_names = [
44
  "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
45
  "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
@@ -50,16 +36,15 @@ feature_names = [
50
  ]
51
 
52
  def classify(model_choice, image=None, *features):
53
- """Classify using ViT (image) or NN (features)."""
54
  if model_choice == "ViT":
55
  if image is None:
56
- return "Please upload an image for ViT classification."
57
- image = image.convert("RGB") # Ensure RGB format
58
- input_tensor = transform(image).unsqueeze(0).to(device) # Preprocess image
59
 
60
  with torch.no_grad():
61
- output = vit_model(input_tensor)
62
- predicted_class = torch.argmax(output, dim=1).item()
63
 
64
  return class_names[predicted_class]
65
 
@@ -67,33 +52,24 @@ def classify(model_choice, image=None, *features):
67
  if any(f is None for f in features):
68
  return "Please enter all 30 numerical features."
69
 
70
- # Convert input features to NumPy array
71
  input_data = np.array(features).reshape(1, -1)
72
-
73
- # Standardize using pre-trained scaler
74
  input_data_std = scaler.transform(input_data)
75
-
76
- # Run prediction using TensorFlow model
77
  prediction = nn_model.predict(input_data_std)
78
  predicted_class = np.argmax(prediction)
79
 
80
  return class_names[predicted_class]
81
 
82
- # Define Gradio UI components
83
  model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
84
  image_input = gr.Image(type="pil", label="Upload Mammogram Image")
 
85
 
86
- # Feature inputs labeled correctly
87
- feature_inputs = [gr.Number(label=feature_names[i]) for i in range(30)]
88
-
89
- # Gradio Interface
90
  iface = gr.Interface(
91
  fn=classify,
92
- inputs=[model_selector, image_input] + feature_inputs, # Image + Feature inputs
93
  outputs="text",
94
  title="Breast Cancer Classification",
95
- description="Choose between ViT (image-based) and Neural Network (feature-based) classification."
96
  )
97
 
98
- # launch app
99
  iface.launch()
 
5
  import numpy as np
6
  import tensorflow as tf
7
  from PIL import Image
8
+ from sklearn.preprocessing import StandardScaler
9
+ import joblib
10
+ from transformers import ViTForImageClassification, ViTImageProcessor
11
 
12
+ # Set device
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # Load ViT model from Hugging Face
16
+ vit_model = ViTForImageClassification.from_pretrained("andromeda01111/ViT_BCC").to(device)
17
+ vit_processor = ViTImageProcessor.from_pretrained("andromeda01111/ViT_BCC")
18
 
19
+ # Load Neural Network model from Hugging Face
20
+ nn_model = tf.keras.models.load_model("andromeda01111/NN_BC")
 
 
 
21
 
22
+ # Load scaler (ensure it's uploaded in the Hugging Face repo)
23
+ scaler = joblib.load("scaler.pkl")
 
 
 
 
24
 
25
  # Class labels
26
  class_names = ["Benign", "Malignant"]
27
 
28
+ # Define feature names
 
 
 
 
 
 
 
 
29
  feature_names = [
30
  "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
31
  "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
 
36
  ]
37
 
38
  def classify(model_choice, image=None, *features):
 
39
  if model_choice == "ViT":
40
  if image is None:
41
+ return "Please upload an image."
42
+ image = image.convert("RGB")
43
+ inputs = vit_processor(images=image, return_tensors="pt").to(device)
44
 
45
  with torch.no_grad():
46
+ outputs = vit_model(**inputs)
47
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
48
 
49
  return class_names[predicted_class]
50
 
 
52
  if any(f is None for f in features):
53
  return "Please enter all 30 numerical features."
54
 
 
55
  input_data = np.array(features).reshape(1, -1)
 
 
56
  input_data_std = scaler.transform(input_data)
 
 
57
  prediction = nn_model.predict(input_data_std)
58
  predicted_class = np.argmax(prediction)
59
 
60
  return class_names[predicted_class]
61
 
62
+ # Gradio UI
63
  model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
64
  image_input = gr.Image(type="pil", label="Upload Mammogram Image")
65
+ feature_inputs = [gr.Number(label=feature) for feature in feature_names]
66
 
 
 
 
 
67
  iface = gr.Interface(
68
  fn=classify,
69
+ inputs=[model_selector, image_input] + feature_inputs,
70
  outputs="text",
71
  title="Breast Cancer Classification",
72
+ description="Choose ViT (image-based) or Neural Network (feature-based) classification."
73
  )
74
 
 
75
  iface.launch()