Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -184,46 +184,46 @@ def create_vit_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
|
|
184 |
|
185 |
def preprocess_image(image):
|
186 |
"""Preprocess the uploaded image"""
|
187 |
-
#
|
188 |
-
|
189 |
-
|
|
|
190 |
# Convert to numpy array
|
191 |
img_array = np.array(image)
|
192 |
|
193 |
-
#
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
# # Preprocess image similar to training
|
200 |
-
# img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
|
201 |
-
# img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
|
202 |
-
# img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
|
203 |
-
|
204 |
-
# # Adjust brightness
|
205 |
-
# target_brightness = 150
|
206 |
-
# current_brightness = np.mean(img_array)
|
207 |
-
# alpha = target_brightness / (current_brightness + 1e-5)
|
208 |
-
# img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
|
209 |
-
|
210 |
-
# # Apply Gaussian blur
|
211 |
-
# img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
|
212 |
|
213 |
# Resize
|
214 |
img_array = cv2.resize(img_array, (256, 256))
|
215 |
|
216 |
# Normalize
|
217 |
img_array = img_array.astype('float32') / 255.0
|
|
|
|
|
218 |
image_features = extract_features(img_array)
|
|
|
|
|
219 |
vit_extractor = create_vit_feature_extractor()
|
220 |
-
|
221 |
-
#
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
scaler = StandardScaler()
|
225 |
image_scaled = scaler.fit_transform(image_combined)
|
226 |
-
|
|
|
227 |
|
228 |
def get_top_predictions(prediction, class_names, top_k=5):
|
229 |
"""Get top k predictions with their probabilities"""
|
@@ -267,8 +267,11 @@ def main():
|
|
267 |
# Preprocess image
|
268 |
processed_image = preprocess_image(image)
|
269 |
|
|
|
|
|
|
|
270 |
# Make prediction
|
271 |
-
prediction = model.predict(
|
272 |
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
|
273 |
|
274 |
# Get top 5 predictions
|
|
|
184 |
|
185 |
def preprocess_image(image):
|
186 |
"""Preprocess the uploaded image"""
|
187 |
+
# Convert to RGB if needed
|
188 |
+
if image.mode != 'RGB':
|
189 |
+
image = image.convert('RGB')
|
190 |
+
|
191 |
# Convert to numpy array
|
192 |
img_array = np.array(image)
|
193 |
|
194 |
+
# Ensure RGB format
|
195 |
+
if len(img_array.shape) == 2: # Grayscale
|
196 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
|
197 |
+
elif img_array.shape[2] == 4: # RGBA
|
198 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
# Resize
|
201 |
img_array = cv2.resize(img_array, (256, 256))
|
202 |
|
203 |
# Normalize
|
204 |
img_array = img_array.astype('float32') / 255.0
|
205 |
+
|
206 |
+
# Extract traditional features
|
207 |
image_features = extract_features(img_array)
|
208 |
+
|
209 |
+
# Create and process ViT features
|
210 |
vit_extractor = create_vit_feature_extractor()
|
211 |
+
|
212 |
+
# Reshape image for ViT processing - THIS IS THE KEY FIX
|
213 |
+
img_for_vit = np.expand_dims(img_array, axis=0) # Add batch dimension
|
214 |
+
image_vit = vit_extractor.predict(img_for_vit)
|
215 |
+
|
216 |
+
# Flatten ViT features if needed
|
217 |
+
image_vit = image_vit.reshape(1, -1) # Ensure 2D shape
|
218 |
+
|
219 |
+
# Combine features
|
220 |
+
image_combined = np.concatenate([image_features.reshape(1, -1), image_vit], axis=1)
|
221 |
+
|
222 |
+
# Scale features
|
223 |
scaler = StandardScaler()
|
224 |
image_scaled = scaler.fit_transform(image_combined)
|
225 |
+
|
226 |
+
return image_scaled.squeeze() # Remove any unnecessary dimensions
|
227 |
|
228 |
def get_top_predictions(prediction, class_names, top_k=5):
|
229 |
"""Get top k predictions with their probabilities"""
|
|
|
267 |
# Preprocess image
|
268 |
processed_image = preprocess_image(image)
|
269 |
|
270 |
+
# Ensure correct shape for prediction
|
271 |
+
processed_image = np.expand_dims(processed_image, axis=0)
|
272 |
+
|
273 |
# Make prediction
|
274 |
+
prediction = model.predict(processed_image)
|
275 |
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
|
276 |
|
277 |
# Get top 5 predictions
|