import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
import joblib
import os
# Cloudinary (used to store user-submitted correction images)
import cloudinary
import cloudinary.uploader
from cloudinary.utils import cloudinary_url
# Cloudinary configuration (credentials are read from environment variables)
cloudinary.config(
    cloud_name=os.getenv("CLOUD"),
    api_key=os.getenv("API"),
    api_secret=os.getenv("SECRET"),
    secure=True
)
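# NOTE: the CLOUD, API, and SECRET environment variables must be set before
# the app starts (e.g. as secrets in the hosting environment); os.getenv()
# returns None for missing variables, and uploads will then fail at request time.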
def upload_to_cloudinary(file_path, label):
    """
    Upload a file to Cloudinary, using the given label as the target folder.
    Returns a dict with the upload result and derived URLs on success,
    or an error message string on failure.
    """
    try:
        # Upload to Cloudinary, grouping images by their label
        upload_result = cloudinary.uploader.upload(
            file_path,
            folder=label,
            public_id=f"{label}_{os.path.basename(file_path)}"
        )
        # Generate an optimized delivery URL
        optimize_url, _ = cloudinary_url(
            upload_result['public_id'],
            fetch_format="auto",
            quality="auto"
        )
        # Generate an auto-cropped 500x500 URL
        auto_crop_url, _ = cloudinary_url(
            upload_result['public_id'],
            width=500,
            height=500,
            crop="auto",
            gravity="auto"
        )
        return {
            "upload_result": upload_result,
            "optimize_url": optimize_url,
            "auto_crop_url": auto_crop_url
        }
    except Exception as e:
        return f"Error uploading to Cloudinary: {str(e)}"
def main():
    st.title("🪨 Stone Classification")
    st.write("Upload an image of a stone to classify its type.")

    # Load model and scaler
    model, scaler = load_model_and_scaler()
    if model is None or scaler is None:
        st.error("Could not load the model or scaler. Please make sure both files exist.")
        return

    # Initialize session state
    if 'predictions' not in st.session_state:
        st.session_state.predictions = None
    if 'uploaded_image' not in st.session_state:
        st.session_state.uploaded_image = None

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            try:
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded image", use_column_width=True)
                st.session_state.uploaded_image = image
                with st.spinner('Analyzing image...'):
                    processed_image = preprocess_image(image, scaler)
                    prediction = model.predict(processed_image, verbose=0)
                    class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
                    st.session_state.predictions = get_top_predictions(prediction, class_names)
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
    with col2:
        st.subheader("Prediction Results")
        if st.session_state.predictions:
            # Display the main prediction
            top_class, top_confidence = st.session_state.predictions[0]
            st.markdown(
                f"""
                <div class='prediction-card'>
                    <h3>Top prediction: Color {top_class}</h3>
                    <h3>Confidence: {top_confidence:.2f}%</h3>
                </div>
                """,
                unsafe_allow_html=True
            )

            # Display confidence bar
            st.progress(top_confidence / 100)

            # Display the top 5 predictions
            st.markdown("### Top 5 Predictions")
            st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
            for class_name, confidence in st.session_state.predictions:
                st.markdown(f"**Color {class_name}: Confidence {confidence:.2f}%**")
                st.progress(confidence / 100)
            st.markdown("</div>", unsafe_allow_html=True)
            # User confirmation section
            st.markdown("### Confirm Model Accuracy")
            st.write("Help us improve the model by confirming whether the prediction is accurate.")

            # Accuracy radio button
            accuracy_option = st.radio(
                "Is the prediction accurate?",
                ["Select", "Accurate", "Inaccurate"],
                index=0
            )

            if accuracy_option == "Inaccurate":
                # Input for the correct grade
                correct_grade = st.selectbox(
                    "Select the correct stone color:",
                    ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'],
                    index=None,
                    placeholder="Select the correct color"
                )

                # Upload button
                if st.button("Upload Image for Correction"):
                    if correct_grade and st.session_state.uploaded_image:
                        # Save the image to a temporary file
                        temp_image_path = f"temp_image_{hash(uploaded_file.name)}.png"
                        st.session_state.uploaded_image.save(temp_image_path)
                        try:
                            # Upload to Cloudinary under the corrected label
                            cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
                            if isinstance(cloudinary_result, dict):
                                st.success(f"Image uploaded successfully for color {correct_grade}")
                                st.write(f"Public URL: {cloudinary_result['upload_result']['secure_url']}")
                            else:
                                st.error(cloudinary_result)
                        except Exception as e:
                            st.error(f"Upload failed: {str(e)}")
                        finally:
                            # Remove the temporary file even if the upload fails
                            if os.path.exists(temp_image_path):
                                os.remove(temp_image_path)
                    else:
                        st.warning("Please select the correct color before uploading.")
        else:
            st.info("Upload an image to see predictions.")

    st.markdown("---")
    st.markdown("Made with ❤️ using Streamlit")
def load_model_and_scaler():
    """Load the trained model and scaler"""
    try:
        model = tf.keras.models.load_model('mlp_model.h5')
        # Load the saved scaler
        scaler = joblib.load('standard_scaler.pkl')
        return model, scaler
    except Exception as e:
        st.error(f"Error loading model or scaler: {str(e)}")
        return None, None
def color_histogram(image, bins=16):
    """Calculate per-channel color histogram features"""
    hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
    hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
    hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()

    # Normalize each histogram to sum to 1 (epsilon guards against empty histograms)
    hist_r = hist_r / (np.sum(hist_r) + 1e-7)
    hist_g = hist_g / (np.sum(hist_g) + 1e-7)
    hist_b = hist_b / (np.sum(hist_b) + 1e-7)

    return np.concatenate([hist_r, hist_g, hist_b])
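# For the default bins=16 this yields a 16 * 3 = 48-dimensional vector.
# NOTE (assumption): preprocess_image() passes the image in as float32 scaled
# to [0, 1] while the histogram range here is [0, 256]; this presumably mirrors
# the training-time pipeline, so it is left unchanged to stay consistent with
# the saved scaler and model.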
def color_moments(image):
    """Calculate color moments (mean, std, skewness) per channel"""
    img = image.astype(np.float32) / 255.0
    moments = []
    for i in range(3):
        channel = img[:, :, i]
        mean = np.mean(channel)
        std = np.std(channel) + 1e-7  # epsilon keeps the division below safe
        skewness = np.mean(((channel - mean) / std) ** 3)
        moments.extend([mean, std, skewness])
    return np.array(moments)
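# Produces a 3 * 3 = 9-dimensional vector (mean, std, skewness for R, G, B).
# NOTE (assumption): the input has already been scaled to [0, 1] upstream, so
# the division by 255 above rescales it a second time; kept as-is to match the
# pipeline the model was presumably trained with.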
def dominant_color_descriptor(image, k=3):
    """Calculate dominant color descriptor via k-means clustering"""
    pixels = image.reshape(-1, 3).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    flags = cv2.KMEANS_RANDOM_CENTERS
    try:
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
        # Count pixels per cluster; minlength keeps the output a fixed length
        # even if a cluster ends up empty
        counts = np.bincount(labels.flatten(), minlength=k)
        percentages = counts / len(labels)
        return np.concatenate([centers.flatten(), percentages])
    except Exception:
        # Fall back to a zero vector of the same length
        # (k centers * 3 channels + k percentages)
        return np.zeros(k * 4)
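# For the default k=3 this yields a 3*3 + 3 = 12-dimensional vector
# (three RGB cluster centers plus their pixel-share percentages).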
def color_coherence_vector(image, k=3):
    """Calculate a simplified color coherence vector.

    This is a coarse approximation of a true CCV: it thresholds the image
    with Otsu's method and records the sizes of the first k connected
    components, duplicating each size to keep the 2*k layout of a CCV.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray = np.uint8(gray)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    num_labels, labels = cv2.connectedComponents(binary)
    ccv = []
    for i in range(1, min(k + 1, num_labels)):
        region_mask = (labels == i)
        total_pixels = np.sum(region_mask)
        ccv.extend([total_pixels, total_pixels])
    # Zero-pad to a fixed length of 2*k
    ccv.extend([0] * (2 * k - len(ccv)))
    return np.array(ccv[:2 * k])
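# Yields a fixed 2 * k = 6-dimensional vector for the default k=3.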
@st.cache_resource
def create_vit_feature_extractor():
    """Create and cache the deep feature extractor.

    Despite the name, this uses an ImageNet-pretrained EfficientNetB0
    backbone with global average pooling, not a Vision Transformer.
    """
    input_shape = (256, 256, 3)
    inputs = layers.Input(shape=input_shape)
    x = layers.Lambda(preprocess_input)(inputs)
    base_model = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=x
    )
    x = layers.GlobalAveragePooling2D()(base_model.output)
    return models.Model(inputs=inputs, outputs=x)
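# EfficientNetB0's pooled output is 1280-dimensional, so the full feature
# vector assembled in extract_features() below should be
# 48 (histogram) + 9 (moments) + 12 (dominant colors) + 6 (CCV) + 1280 = 1355
# values; the saved StandardScaler must have been fitted on vectors of the
# same length.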
def extract_features(image):
    """Extract all features from an image"""
    # Traditional color features
    hist_features = color_histogram(image)
    moment_features = color_moments(image)
    dominant_features = dominant_color_descriptor(image)
    ccv_features = color_coherence_vector(image)
    traditional_features = np.concatenate([
        hist_features,
        moment_features,
        dominant_features,
        ccv_features
    ])

    # Deep features from the cached EfficientNetB0 extractor
    feature_extractor = create_vit_feature_extractor()
    deep_features = feature_extractor.predict(
        np.expand_dims(image, axis=0),
        verbose=0
    )

    # Combine all features into a single vector
    return np.concatenate([traditional_features, deep_features.flatten()])
def preprocess_image(image, scaler):
    """Preprocess the uploaded image into a scaled feature vector"""
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Convert to numpy array, resize, and scale to [0, 1]
    img_array = np.array(image)
    img_array = cv2.resize(img_array, (256, 256))
    img_array = img_array.astype('float32') / 255.0

    # Extract all features
    features = extract_features(img_array)

    # Scale features using the provided scaler
    scaled_features = scaler.transform(features.reshape(1, -1))
    return scaled_features
def get_top_predictions(prediction, class_names):
    """Return the top 5 (class_name, confidence_percent) pairs"""
    # Softmax the raw model output, then take the 5 highest-probability classes
    probabilities = tf.nn.softmax(prediction[0]).numpy()
    top_indices = np.argsort(probabilities)[-5:][::-1]
    return [(class_names[i], probabilities[i] * 100) for i in top_indices]
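# Example of the returned structure (illustrative values only):
#     [('9.5', 62.10), ('9.2', 21.33), ('9', 8.04), ('9.7', 5.12), ('8.5', 1.95)]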
if __name__ == "__main__":
    main()