quantumbit's picture
Update app.py
4e6f262 verified
raw
history blame
4.29 kB
from flask import Flask, request, jsonify
import numpy as np
import tensorflow as tf
from PIL import Image
import io
import base64
import re
import joblib
import os
app = Flask(__name__)

# Load every model once at import time so each request only pays for inference.
# The directory is anchored to this file rather than os.getcwd(): a cwd-based
# path is absolute, but it still depends on where the server process was
# launched from, so starting the app from another directory broke loading.
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")

# NOTE: joblib.load unpickles its input, which can execute arbitrary code;
# only ever point MODEL_DIR at model files you trust.
models = {
    "cnn": tf.keras.models.load_model(os.path.join(MODEL_DIR, "mnist_cnn_model.h5")),
    "svm": joblib.load(os.path.join(MODEL_DIR, "mnist_svm.pkl")),
    "logistic": joblib.load(os.path.join(MODEL_DIR, "mnist_logistic_regression.pkl")),
    "random_forest": joblib.load(os.path.join(MODEL_DIR, "mnist_random_forest.pkl")),
}
def preprocess_image(image, model_type):
    """Convert an image into the input array the chosen model expects.

    The image is resized to 28x28, converted to 8-bit grayscale ("L") and
    its pixel values divided by 255.0.

    Args:
        image: PIL image to classify.
        model_type: model key; ``"cnn"`` receives a 4-D batch, every other
            value receives a single flattened row.

    Returns:
        A ``(1, 28, 28, 1)`` array for ``"cnn"``, otherwise ``(1, 784)``.
    """
    grayscale = image.resize((28, 28)).convert('L')
    pixels = np.array(grayscale) / 255.0

    if model_type != "cnn":
        # Scikit-learn style estimators take one flattened sample per row.
        return pixels.flatten().reshape(1, -1)

    # Keras CNN input: prepend a batch axis and append a channel axis.
    return pixels[np.newaxis, ..., np.newaxis]
def create_simulated_scores(predicted_digit, *, num_classes=10, floor=0.01):
    """Build a placeholder score vector for models that cannot report confidences.

    Every class receives ``floor``, and the predicted class additionally
    receives the leftover mass, so the scores sum to 1.0 and peak at the
    prediction. With the defaults, the predicted digit scores 0.91 and every
    other digit scores 0.01.

    Args:
        predicted_digit: index of the predicted class;
            ``0 <= predicted_digit < num_classes``.
        num_classes: length of the returned list (10 for MNIST digits).
        floor: baseline score for each class; requires
            ``0 <= floor * num_classes <= 1``.

    Returns:
        A list of ``num_classes`` floats summing to 1.0.

    Raises:
        ValueError: if ``predicted_digit`` is out of range or the parameters
            are invalid. Without this check, a negative index would silently
            boost the wrong class (Python list indices wrap around).
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be positive, got {num_classes}")
    if not 0 <= predicted_digit < num_classes:
        raise ValueError(
            f"predicted_digit must be in [0, {num_classes}), got {predicted_digit}"
        )
    if not 0.0 <= floor * num_classes <= 1.0:
        raise ValueError(
            f"floor * num_classes must lie in [0, 1], got {floor * num_classes}"
        )

    scores = [floor] * num_classes
    # Hand the predicted class whatever probability mass the floor leaves over.
    scores[predicted_digit] += 1.0 - sum(scores)
    return scores
@app.route('/')
def home():
    """Return a JSON summary of the API: its name, loaded models and endpoints."""
    # NOTE(review): no /get_classification_report route is defined in the
    # visible source; confirm it exists elsewhere or drop it from this listing.
    endpoint_help = {
        "/predict": "POST - Send image and model_type",
        "/get_classification_report": "POST - Get model metrics",
    }
    payload = {
        "message": "MNIST Classifier API",
        "available_models": [name for name in models],
        "endpoints": endpoint_help,
    }
    return jsonify(payload)
# Strips any base64 data-URL header ("data:image/png;base64,",
# "data:image/jpeg;base64,", ...). The previous PNG-only pattern let other
# headers through, and base64.b64decode silently discards non-alphabet
# characters, so the leftover header text corrupted the decoded bytes.
_DATA_URL_PREFIX = re.compile(r'^data:image/[\w.+-]+;base64,')

# Number of digit classes; per-class score vectors must have this length.
_NUM_CLASSES = 10


def _decode_image(data):
    """Decode base64 image data (optionally a data URL) into a PIL image."""
    img_data = _DATA_URL_PREFIX.sub('', data)
    return Image.open(io.BytesIO(base64.b64decode(img_data)))


def _svm_decision_scores(model, processed_image):
    """Return one non-negative decision score per digit, or None if unavailable.

    Scores are row 0 of ``model.decision_function``, shifted up so the minimum
    is 0 when any score is negative. Returns None, so the caller falls back
    to simulated scores, when:

    - the model has no ``decision_function``;
    - the call fails;
    - the output is not a 2-D array with one column per digit (e.g. a
      one-vs-one SVC yields pairwise columns that cannot be mapped to digits).
    """
    if not hasattr(model, "decision_function"):
        return None
    try:
        decision_scores = np.asarray(model.decision_function(processed_image))
    except Exception:
        # Best effort: a failure here degrades to simulated scores.
        app.logger.warning("decision_function failed; using simulated scores", exc_info=True)
        return None
    if decision_scores.ndim != 2 or decision_scores.shape[1] != _NUM_CLASSES:
        return None
    scores = decision_scores[0].tolist()
    min_score = min(scores)
    if min_score < 0:
        scores = [score - min_score for score in scores]
    return scores


def _probability_scores(model, processed_image):
    """Return ``model.predict_proba`` for the single sample, or None if unavailable."""
    if not hasattr(model, "predict_proba"):
        return None
    try:
        return model.predict_proba(processed_image)[0].tolist()
    except Exception:
        # Best effort: a failure here degrades to simulated scores.
        app.logger.warning("predict_proba failed; using simulated scores", exc_info=True)
        return None


def _predict_with_scores(model, model_type, processed_image):
    """Classify one preprocessed sample.

    Returns ``(predicted_digit, confidence_scores, score_type)``, where
    ``score_type`` is "probability", "decision_distance" or "simulated".
    """
    if model_type == "cnn":
        # verbose=0 stops Keras printing a progress bar on every request.
        probabilities = model.predict(processed_image, verbose=0)[0]
        return np.argmax(probabilities), probabilities.tolist(), "probability"

    predicted_digit = model.predict(processed_image)[0]
    if model_type == "svm":
        scores, score_type = _svm_decision_scores(model, processed_image), "decision_distance"
    else:
        scores, score_type = _probability_scores(model, processed_image), "probability"
    if scores is None:
        scores, score_type = create_simulated_scores(int(predicted_digit)), "simulated"
    return predicted_digit, scores, score_type


@app.route('/predict', methods=['POST'])
def predict():
    """Classify a base64-encoded digit image with the requested model.

    Expects a JSON body with:
      - ``image``: base64 image data, optionally prefixed by a
        ``data:image/<subtype>;base64,`` header.
      - ``model_type``: one of the keys of ``models``.

    Returns:
        On success, JSON ``{'digit', 'confidence_scores', 'score_type'}``.
        On any failure, ``{'error': <message>}``.

    NOTE(review): errors are sent with HTTP 200, so clients must check for the
    ``'error'`` key. This is kept unchanged for backward compatibility.
    """
    try:
        data = request.json['image']
        model_type = request.json['model_type']

        # Reject unknown models before doing any decoding work.
        if model_type not in models:
            return jsonify({'error': 'Model not found'})

        # `with` closes the PIL image once the pixel array has been extracted.
        with _decode_image(data) as img:
            processed_image = preprocess_image(img, model_type)

        predicted_digit, confidence_scores, score_type = _predict_with_scores(
            models[model_type], model_type, processed_image
        )
        return jsonify({
            'digit': int(predicted_digit),
            'confidence_scores': confidence_scores,
            'score_type': score_type
        })
    except Exception as e:
        # API boundary: keep the traceback in the server log and report the
        # message to the client.
        app.logger.exception("Prediction request failed")
        return jsonify({'error': str(e)})
if __name__ == '__main__':
    # Run the built-in development server on every interface, port 7860.
    # Port presumably chosen to match the hosting platform; confirm.
    app.run(port=7860, host='0.0.0.0')