Update app.py
app.py
CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
 from tensorflow.keras import layers, models
 from tensorflow.keras.applications import EfficientNetB0
 from tensorflow.keras.applications.efficientnet import preprocess_input
-
+import joblib
 import io
 
 # Set page config
@@ -42,14 +42,17 @@ st.markdown("""
 </style>
 """, unsafe_allow_html=True)
 
+# Cache the model loading
 @st.cache_resource
-def
-    """Load the trained model"""
+def load_model_and_scaler():
+    """Load the trained model and scaler"""
     try:
-
+        model = tf.keras.models.load_model('mlp_model.h5')
+        scaler = joblib.load('scaler.save')
+        return model, scaler
     except Exception as e:
-        st.error(f"Error loading model: {str(e)}")
-        return None
+        st.error(f"Error loading model or scaler: {str(e)}")
+        return None, None
 
 def color_histogram(image, bins=16):
     """Calculate color histogram features"""
@@ -57,7 +60,6 @@ def color_histogram(image, bins=16):
     hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
     hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
 
-    # Normalize histograms
    hist_r = hist_r / (np.sum(hist_r) + 1e-7)
    hist_g = hist_g / (np.sum(hist_g) + 1e-7)
    hist_b = hist_b / (np.sum(hist_b) + 1e-7)
@@ -72,7 +74,7 @@ def color_moments(image):
     for i in range(3):
         channel = img[:,:,i]
         mean = np.mean(channel)
-        std = np.std(channel) + 1e-7
+        std = np.std(channel) + 1e-7
         skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
         moments.extend([mean, std, skewness])
 
@@ -91,7 +93,7 @@ def dominant_color_descriptor(image, k=3):
         percentages = counts / len(labels)
         return np.concatenate([centers.flatten(), percentages])
     except Exception:
-        return np.zeros(k * 4)
+        return np.zeros(k * 4)
 
 def color_coherence_vector(image, k=3):
     """Calculate color coherence vector"""
@@ -107,13 +109,12 @@ def color_coherence_vector(image, k=3):
         total_pixels = np.sum(region_mask)
         ccv.extend([total_pixels, total_pixels])
 
-    # Pad with zeros if needed
     ccv.extend([0] * (2 * k - len(ccv)))
     return np.array(ccv[:2*k])
 
 @st.cache_resource
-def create_feature_extractor():
-    """Create and cache the feature extractor"""
+def create_vit_feature_extractor():
+    """Create and cache the ViT feature extractor"""
     input_shape = (256, 256, 3)
     inputs = layers.Input(shape=input_shape)
     x = layers.Lambda(preprocess_input)(inputs)
@@ -129,23 +130,30 @@ def create_feature_extractor():
 
 def extract_features(image):
     """Extract all features from an image"""
-    #
-
+    # Traditional features
+    hist_features = color_histogram(image)
+    moment_features = color_moments(image)
+    dominant_features = dominant_color_descriptor(image)
+    ccv_features = color_coherence_vector(image)
 
-
-    hist_features = color_histogram(image_uint8)
-    moment_features = color_moments(image_uint8)
-    dominant_features = dominant_color_descriptor(image_uint8)
-    ccv_features = color_coherence_vector(image_uint8)
-
-    return np.concatenate([
+    traditional_features = np.concatenate([
         hist_features,
         moment_features,
         dominant_features,
         ccv_features
     ])
+
+    # Deep features using ViT
+    feature_extractor = create_vit_feature_extractor()
+    vit_features = feature_extractor.predict(
+        np.expand_dims(image, axis=0),
+        verbose=0
+    )
+
+    # Combine all features
+    return np.concatenate([traditional_features, vit_features.flatten()])
 
-def preprocess_image(image):
+def preprocess_image(image, scaler):
     """Preprocess the uploaded image"""
     # Convert to RGB if needed
     if image.mode != 'RGB':
@@ -156,25 +164,13 @@ def preprocess_image(image):
     img_array = cv2.resize(img_array, (256, 256))
     img_array = img_array.astype('float32') / 255.0
 
-    # Extract
-
+    # Extract all features
+    features = extract_features(img_array)
 
-    #
-
-    deep_features = feature_extractor.predict(
-        np.expand_dims(img_array, axis=0),
-        verbose=0
-    )
-
-    # Combine features
-    combined_features = np.concatenate([
-        traditional_features.reshape(1, -1),
-        deep_features.reshape(1, -1)
-    ], axis=1)
+    # Scale features using the provided scaler
+    scaled_features = scaler.transform(features.reshape(1, -1))
 
-
-    scaler = StandardScaler()
-    return scaler.fit_transform(combined_features)
+    return scaled_features
 
 def get_top_predictions(prediction, class_names, top_k=5):
     """Get top k predictions with their probabilities"""
@@ -188,6 +184,12 @@ def main():
     st.title("🪨 Stone Classification")
     st.write("Upload an image of a stone to classify its type")
 
+    # Load model and scaler
+    model, scaler = load_model_and_scaler()
+    if model is None or scaler is None:
+        st.error("Failed to load model or scaler. Please ensure both files exist.")
+        return
+
     # Initialize session state
     if 'predictions' not in st.session_state:
         st.session_state.predictions = None
@@ -204,12 +206,7 @@ def main():
         st.image(image, caption="Uploaded Image", use_column_width=True)
 
         with st.spinner('Analyzing image...'):
-
-            if model is None:
-                st.error("Failed to load model")
-                return
-
-            processed_image = preprocess_image(image)
+            processed_image = preprocess_image(image, scaler)
             prediction = model.predict(processed_image, verbose=0)
 
            class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']