quantumbit commited on
Commit
c8d22c3
·
verified ·
1 Parent(s): 3f54158

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ import re
8
+ import joblib
9
+ import os
10
+
11
+ app = Flask(__name__)
12
+
13
+ # Load all models - use absolute paths for Hugging Face
14
+ MODEL_DIR = os.path.join(os.getcwd(), "models")
15
+ models = {
16
+ "cnn": tf.keras.models.load_model(os.path.join(MODEL_DIR, "mnist_cnn_model.h5")),
17
+ "svm": joblib.load(os.path.join(MODEL_DIR, "mnist_svm.pkl")),
18
+ "logistic": joblib.load(os.path.join(MODEL_DIR, "mnist_logistic_regression.pkl")),
19
+ "random_forest": joblib.load(os.path.join(MODEL_DIR, "mnist_random_forest.pkl"))
20
+ }
21
+
22
+ # [Keep your existing classification_reports, preprocess_image,
23
+ # and create_simulated_scores functions exactly as they are]
24
+ # Classification reports for each model
25
+ classification_reports = {
26
+ "cnn": """
27
+ precision recall f1-score support
28
+ 0 0.99 1.00 0.99 980
29
+ 1 1.00 1.00 1.00 1135
30
+ 2 0.99 0.99 0.99 1032
31
+ 3 0.99 1.00 0.99 1010
32
+ 4 1.00 0.99 0.99 982
33
+ 5 0.98 0.99 0.99 892
34
+ 6 1.00 0.98 0.99 958
35
+ 7 0.99 0.99 0.99 1028
36
+ 8 1.00 0.99 0.99 974
37
+ 9 0.99 0.99 0.99 1009
38
+ accuracy 0.99 10000
39
+ macro avg 0.99 0.99 0.99 10000
40
+ weighted avg 0.99 0.99 0.99 10000
41
+ """,
42
+ "svm": """
43
+ precision recall f1-score support
44
+ 0 0.9874 0.9896 0.9885 1343
45
+ 1 0.9882 0.9925 0.9903 1600
46
+ 2 0.9706 0.9819 0.9762 1380
47
+ 3 0.9783 0.9749 0.9766 1433
48
+ 4 0.9777 0.9822 0.9800 1295
49
+ 5 0.9827 0.9796 0.9811 1273
50
+ 6 0.9858 0.9921 0.9889 1396
51
+ 7 0.9768 0.9807 0.9788 1503
52
+ 8 0.9813 0.9683 0.9748 1357
53
+ 9 0.9807 0.9669 0.9738 1420
54
+ accuracy 0.9810 14000
55
+ macro avg 0.9809 0.9809 0.9809 14000
56
+ weighted avg 0.9810 0.9810 0.9810 14000
57
+ """,
58
+ "random_forest": """
59
+ precision recall f1-score support
60
+ 0 0.9844 0.9866 0.9855 1343
61
+ 1 0.9831 0.9831 0.9831 1600
62
+ 2 0.9522 0.9674 0.9597 1380
63
+ 3 0.9579 0.9532 0.9556 1433
64
+ 4 0.9617 0.9699 0.9658 1295
65
+ 5 0.9707 0.9631 0.9669 1273
66
+ 6 0.9800 0.9828 0.9814 1396
67
+ 7 0.9668 0.9681 0.9674 1503
68
+ 8 0.9599 0.9528 0.9564 1357
69
+ 9 0.9566 0.9465 0.9515 1420
70
+ accuracy 0.9675 14000
71
+ macro avg 0.9673 0.9674 0.9673 14000
72
+ weighted avg 0.9675 0.9675 0.9675 14000
73
+ """,
74
+ "logistic": """
75
+ precision recall f1-score support
76
+ 0 0.9636 0.9650 0.9643 1343
77
+ 1 0.9433 0.9675 0.9553 1600
78
+ 2 0.9113 0.8935 0.9023 1380
79
+ 3 0.9021 0.8939 0.8980 1433
80
+ 4 0.9225 0.9290 0.9257 1295
81
+ 5 0.8846 0.8790 0.8818 1273
82
+ 6 0.9420 0.9534 0.9477 1396
83
+ 7 0.9273 0.9421 0.9347 1503
84
+ 8 0.8973 0.8696 0.8832 1357
85
+ 9 0.9019 0.9000 0.9010 1420
86
+ accuracy 0.9204 14000
87
+ macro avg 0.9196 0.9193 0.9194 14000
88
+ weighted avg 0.9201 0.9204 0.9202 14000
89
+ """
90
+ }
91
+
92
+ # Preprocess image before prediction
93
+ def preprocess_image(image, model_type):
94
+ image = image.resize((28, 28)).convert('L')
95
+ img_array = np.array(image) / 255.0
96
+
97
+ if model_type == "cnn":
98
+ return np.expand_dims(np.expand_dims(img_array, axis=0), axis=-1)
99
+ else:
100
+ return img_array.flatten().reshape(1, -1)
101
+
102
+ def create_simulated_scores(predicted_digit):
103
+ scores = [0.01] * 10
104
+ remaining = 1.0 - sum(scores)
105
+ scores[predicted_digit] += remaining
106
+ return scores
107
+
108
+ @app.route('/')
109
+ def home():
110
+ return jsonify({
111
+ "message": "MNIST Classifier API",
112
+ "available_models": list(models.keys()),
113
+ "endpoints": {
114
+ "/predict": "POST - Send image and model_type",
115
+ "/get_classification_report": "POST - Get model metrics"
116
+ }
117
+ })
118
+
119
+ # [Keep your existing /get_classification_report and /predict routes exactly as they are]
120
+ @app.route('/get_classification_report', methods=['POST'])
121
+ def get_classification_report():
122
+ model_type = request.json['model_type']
123
+ if model_type in classification_reports:
124
+ return jsonify({'report': classification_reports[model_type]})
125
+ return jsonify({'error': 'Model not found'})
126
+
127
+ @app.route('/predict', methods=['POST'])
128
+ def predict():
129
+ try:
130
+ data = request.json['image']
131
+ model_type = request.json['model_type']
132
+
133
+ # Process image directly without saving
134
+ img_data = re.sub('^data:image/png;base64,', '', data)
135
+ img = Image.open(io.BytesIO(base64.b64decode(img_data)))
136
+ processed_image = preprocess_image(img, model_type)
137
+
138
+ if model_type not in models:
139
+ return jsonify({'error': 'Model not found'})
140
+
141
+ model = models[model_type]
142
+
143
+ if model_type == "cnn":
144
+ prediction = model.predict(processed_image)
145
+ predicted_digit = np.argmax(prediction)
146
+ confidence_scores = prediction[0].tolist()
147
+ score_type = "probability"
148
+
149
+ elif model_type == "svm":
150
+ predicted_digit = model.predict(processed_image)[0]
151
+ if hasattr(model, "decision_function"):
152
+ try:
153
+ decision_scores = model.decision_function(processed_image)
154
+ if len(decision_scores.shape) == 2:
155
+ confidence_scores = decision_scores[0].tolist()
156
+ else:
157
+ confidence_scores = [0] * 10
158
+ for i in range(10):
159
+ confidence_scores[i] = sum(1 for score in decision_scores[0] if score > 0)
160
+ min_score = min(confidence_scores)
161
+ if min_score < 0:
162
+ confidence_scores = [score - min_score for score in confidence_scores]
163
+ score_type = "decision_distance"
164
+ except Exception:
165
+ confidence_scores = create_simulated_scores(int(predicted_digit))
166
+ score_type = "simulated"
167
+ else:
168
+ confidence_scores = create_simulated_scores(int(predicted_digit))
169
+ score_type = "simulated"
170
+
171
+ else:
172
+ predicted_digit = model.predict(processed_image)[0]
173
+ if hasattr(model, "predict_proba"):
174
+ try:
175
+ confidence_scores = model.predict_proba(processed_image)[0].tolist()
176
+ score_type = "probability"
177
+ except Exception:
178
+ confidence_scores = create_simulated_scores(int(predicted_digit))
179
+ score_type = "simulated"
180
+ else:
181
+ confidence_scores = create_simulated_scores(int(predicted_digit))
182
+ score_type = "simulated"
183
+
184
+ return jsonify({
185
+ 'digit': int(predicted_digit),
186
+ 'confidence_scores': confidence_scores,
187
+ 'score_type': score_type
188
+ })
189
+
190
+ except Exception as e:
191
+ return jsonify({'error': str(e)})
192
+
193
+ if __name__ == '__main__':
194
+ app.run(host='0.0.0.0', port=7860)