SonFox2920 commited on
Commit
33f7f53
·
verified ·
1 Parent(s): d71712b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -45
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
  from tensorflow.keras import layers, models
7
  from tensorflow.keras.applications import EfficientNetB0
8
  from tensorflow.keras.applications.efficientnet import preprocess_input
9
- from sklearn.preprocessing import StandardScaler
10
  import io
11
 
12
  # Set page config
@@ -42,14 +42,17 @@ st.markdown("""
42
  </style>
43
  """, unsafe_allow_html=True)
44
 
 
45
  @st.cache_resource
46
- def load_model():
47
- """Load the trained model"""
48
  try:
49
- return tf.keras.models.load_model('mlp_model.h5')
 
 
50
  except Exception as e:
51
- st.error(f"Error loading model: {str(e)}")
52
- return None
53
 
54
  def color_histogram(image, bins=16):
55
  """Calculate color histogram features"""
@@ -57,7 +60,6 @@ def color_histogram(image, bins=16):
57
  hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
58
  hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
59
 
60
- # Normalize histograms
61
  hist_r = hist_r / (np.sum(hist_r) + 1e-7)
62
  hist_g = hist_g / (np.sum(hist_g) + 1e-7)
63
  hist_b = hist_b / (np.sum(hist_b) + 1e-7)
@@ -72,7 +74,7 @@ def color_moments(image):
72
  for i in range(3):
73
  channel = img[:,:,i]
74
  mean = np.mean(channel)
75
- std = np.std(channel) + 1e-7 # Avoid division by zero
76
  skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
77
  moments.extend([mean, std, skewness])
78
 
@@ -91,7 +93,7 @@ def dominant_color_descriptor(image, k=3):
91
  percentages = counts / len(labels)
92
  return np.concatenate([centers.flatten(), percentages])
93
  except Exception:
94
- return np.zeros(k * 4) # Return zero vector if clustering fails
95
 
96
  def color_coherence_vector(image, k=3):
97
  """Calculate color coherence vector"""
@@ -107,13 +109,12 @@ def color_coherence_vector(image, k=3):
107
  total_pixels = np.sum(region_mask)
108
  ccv.extend([total_pixels, total_pixels])
109
 
110
- # Pad with zeros if needed
111
  ccv.extend([0] * (2 * k - len(ccv)))
112
  return np.array(ccv[:2*k])
113
 
114
  @st.cache_resource
115
- def create_feature_extractor():
116
- """Create and cache the feature extractor model"""
117
  input_shape = (256, 256, 3)
118
  inputs = layers.Input(shape=input_shape)
119
  x = layers.Lambda(preprocess_input)(inputs)
@@ -129,23 +130,30 @@ def create_feature_extractor():
129
 
130
  def extract_features(image):
131
  """Extract all features from an image"""
132
- # Convert image to uint8 for OpenCV operations
133
- image_uint8 = (image * 255).astype(np.uint8)
 
 
 
134
 
135
- # Extract traditional features
136
- hist_features = color_histogram(image_uint8)
137
- moment_features = color_moments(image_uint8)
138
- dominant_features = dominant_color_descriptor(image_uint8)
139
- ccv_features = color_coherence_vector(image_uint8)
140
-
141
- return np.concatenate([
142
  hist_features,
143
  moment_features,
144
  dominant_features,
145
  ccv_features
146
  ])
 
 
 
 
 
 
 
 
 
 
147
 
148
- def preprocess_image(image):
149
  """Preprocess the uploaded image"""
150
  # Convert to RGB if needed
151
  if image.mode != 'RGB':
@@ -156,25 +164,13 @@ def preprocess_image(image):
156
  img_array = cv2.resize(img_array, (256, 256))
157
  img_array = img_array.astype('float32') / 255.0
158
 
159
- # Extract traditional features
160
- traditional_features = extract_features(img_array)
161
 
162
- # Extract deep features
163
- feature_extractor = create_feature_extractor()
164
- deep_features = feature_extractor.predict(
165
- np.expand_dims(img_array, axis=0),
166
- verbose=0
167
- )
168
-
169
- # Combine features
170
- combined_features = np.concatenate([
171
- traditional_features.reshape(1, -1),
172
- deep_features.reshape(1, -1)
173
- ], axis=1)
174
 
175
- # Scale features
176
- scaler = StandardScaler()
177
- return scaler.fit_transform(combined_features)
178
 
179
  def get_top_predictions(prediction, class_names, top_k=5):
180
  """Get top k predictions with their probabilities"""
@@ -188,6 +184,12 @@ def main():
188
  st.title("🪨 Stone Classification")
189
  st.write("Upload an image of a stone to classify its type")
190
 
 
 
 
 
 
 
191
  # Initialize session state
192
  if 'predictions' not in st.session_state:
193
  st.session_state.predictions = None
@@ -204,12 +206,7 @@ def main():
204
  st.image(image, caption="Uploaded Image", use_column_width=True)
205
 
206
  with st.spinner('Analyzing image...'):
207
- model = load_model()
208
- if model is None:
209
- st.error("Failed to load model")
210
- return
211
-
212
- processed_image = preprocess_image(image)
213
  prediction = model.predict(processed_image, verbose=0)
214
 
215
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
 
6
  from tensorflow.keras import layers, models
7
  from tensorflow.keras.applications import EfficientNetB0
8
  from tensorflow.keras.applications.efficientnet import preprocess_input
9
+ import joblib
10
  import io
11
 
12
  # Set page config
 
42
  </style>
43
  """, unsafe_allow_html=True)
44
 
45
+ # Cache the model loading
46
  @st.cache_resource
47
+ def load_model_and_scaler():
48
+ """Load the trained model and scaler"""
49
  try:
50
+ model = tf.keras.models.load_model('mlp_model.h5')
51
+ scaler = joblib.load('scaler.save')
52
+ return model, scaler
53
  except Exception as e:
54
+ st.error(f"Error loading model or scaler: {str(e)}")
55
+ return None, None
56
 
57
  def color_histogram(image, bins=16):
58
  """Calculate color histogram features"""
 
60
  hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
61
  hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
62
 
 
63
  hist_r = hist_r / (np.sum(hist_r) + 1e-7)
64
  hist_g = hist_g / (np.sum(hist_g) + 1e-7)
65
  hist_b = hist_b / (np.sum(hist_b) + 1e-7)
 
74
  for i in range(3):
75
  channel = img[:,:,i]
76
  mean = np.mean(channel)
77
+ std = np.std(channel) + 1e-7
78
  skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
79
  moments.extend([mean, std, skewness])
80
 
 
93
  percentages = counts / len(labels)
94
  return np.concatenate([centers.flatten(), percentages])
95
  except Exception:
96
+ return np.zeros(k * 4)
97
 
98
  def color_coherence_vector(image, k=3):
99
  """Calculate color coherence vector"""
 
109
  total_pixels = np.sum(region_mask)
110
  ccv.extend([total_pixels, total_pixels])
111
 
 
112
  ccv.extend([0] * (2 * k - len(ccv)))
113
  return np.array(ccv[:2*k])
114
 
115
  @st.cache_resource
116
+ def create_vit_feature_extractor():
117
+ """Create and cache the ViT feature extractor"""
118
  input_shape = (256, 256, 3)
119
  inputs = layers.Input(shape=input_shape)
120
  x = layers.Lambda(preprocess_input)(inputs)
 
130
 
131
  def extract_features(image):
132
  """Extract all features from an image"""
133
+ # Traditional features
134
+ hist_features = color_histogram(image)
135
+ moment_features = color_moments(image)
136
+ dominant_features = dominant_color_descriptor(image)
137
+ ccv_features = color_coherence_vector(image)
138
 
139
+ traditional_features = np.concatenate([
 
 
 
 
 
 
140
  hist_features,
141
  moment_features,
142
  dominant_features,
143
  ccv_features
144
  ])
145
+
146
+ # Deep features using ViT
147
+ feature_extractor = create_vit_feature_extractor()
148
+ vit_features = feature_extractor.predict(
149
+ np.expand_dims(image, axis=0),
150
+ verbose=0
151
+ )
152
+
153
+ # Combine all features
154
+ return np.concatenate([traditional_features, vit_features.flatten()])
155
 
156
+ def preprocess_image(image, scaler):
157
  """Preprocess the uploaded image"""
158
  # Convert to RGB if needed
159
  if image.mode != 'RGB':
 
164
  img_array = cv2.resize(img_array, (256, 256))
165
  img_array = img_array.astype('float32') / 255.0
166
 
167
+ # Extract all features
168
+ features = extract_features(img_array)
169
 
170
+ # Scale features using the provided scaler
171
+ scaled_features = scaler.transform(features.reshape(1, -1))
 
 
 
 
 
 
 
 
 
 
172
 
173
+ return scaled_features
 
 
174
 
175
  def get_top_predictions(prediction, class_names, top_k=5):
176
  """Get top k predictions with their probabilities"""
 
184
  st.title("🪨 Stone Classification")
185
  st.write("Upload an image of a stone to classify its type")
186
 
187
+ # Load model and scaler
188
+ model, scaler = load_model_and_scaler()
189
+ if model is None or scaler is None:
190
+ st.error("Failed to load model or scaler. Please ensure both files exist.")
191
+ return
192
+
193
  # Initialize session state
194
  if 'predictions' not in st.session_state:
195
  st.session_state.predictions = None
 
206
  st.image(image, caption="Uploaded Image", use_column_width=True)
207
 
208
  with st.spinner('Analyzing image...'):
209
+ processed_image = preprocess_image(image, scaler)
 
 
 
 
 
210
  prediction = model.predict(processed_image, verbose=0)
211
 
212
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']