Update main.py
Browse files
main.py
CHANGED
@@ -1,69 +1,67 @@
|
|
1 |
-
from flask import Flask, request, render_template
|
2 |
-
import numpy as np
|
3 |
-
import os
|
4 |
-
from PIL import Image
|
5 |
-
import tensorflow as tf
|
6 |
-
import os
|
7 |
-
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
8 |
-
from tensorflow.keras.models import load_model
|
9 |
-
import segmentation_models as sm
|
10 |
-
from tensorflow.keras import backend as K
|
11 |
-
import matplotlib.pyplot as plt
|
12 |
-
|
13 |
-
app = Flask(__name__)
|
14 |
-
app.config['UPLOAD_FOLDER'] = 'static/uploads'
|
15 |
-
|
16 |
-
# Load your model
|
17 |
-
weights = [0.1666] * 6 # Adjust if necessary
|
18 |
-
dice_loss = sm.losses.DiceLoss(class_weights=weights)
|
19 |
-
focal_loss = sm.losses.CategoricalFocalLoss()
|
20 |
-
total_loss = dice_loss + (1 * focal_loss)
|
21 |
-
|
22 |
-
def jacard_coef(y_true, y_pred):
|
23 |
-
y_true_f = K.flatten(y_true)
|
24 |
-
y_pred_f = K.flatten(y_pred)
|
25 |
-
intersection = K.sum(y_true_f * y_pred_f)
|
26 |
-
return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
|
27 |
-
|
28 |
-
model_path = r"C:\Users\sriva\Videos\dubai segmentation\satellite_standard_unet_100epochs.hdf5"
|
29 |
-
custom_objects = {
|
30 |
-
"dice_loss_plus_1focal_loss": total_loss,
|
31 |
-
"jacard_coef": jacard_coef
|
32 |
-
}
|
33 |
-
model = load_model(model_path, custom_objects=custom_objects)
|
34 |
-
|
35 |
-
@app.route('/')
|
36 |
-
def index():
|
37 |
-
return render_template('index.html')
|
38 |
-
|
39 |
-
@app.route('/predict', methods=['POST'])
|
40 |
-
def predict():
|
41 |
-
if 'file' not in request.files:
|
42 |
-
return "No file uploaded.", 400
|
43 |
-
file = request.files['file']
|
44 |
-
if file.filename == '':
|
45 |
-
return "No selected file.", 400
|
46 |
-
|
47 |
-
# Save the uploaded image
|
48 |
-
file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
|
49 |
-
file.save(file_path)
|
50 |
-
|
51 |
-
# Load and preprocess the image
|
52 |
-
test_img = Image.open(file_path)
|
53 |
-
desired_width = 256
|
54 |
-
desired_height = 256
|
55 |
-
test_img = test_img.resize((desired_width, desired_height))
|
56 |
-
test_img = np.array(test_img)
|
57 |
-
test_img_input = np.expand_dims(test_img, 0)
|
58 |
-
|
59 |
-
# Make the prediction
|
60 |
-
prediction = model.predict(test_img_input)
|
61 |
-
predicted_img = np.argmax(prediction, axis=3)[0, :, :]
|
62 |
-
|
63 |
-
# Save the predicted image
|
64 |
-
plt.imsave(os.path.join(app.config['UPLOAD_FOLDER'], 'predicted_' + file.filename), predicted_img, cmap='jet')
|
65 |
-
|
66 |
-
return render_template('index.html', uploaded_image=file.filename, predicted_image='predicted_' + file.filename)
|
67 |
-
|
68 |
-
if __name__ == '__main__':
|
69 |
-
app.run(debug=True)
|
|
|
1 |
+
from flask import Flask, request, render_template
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
import tensorflow as tf
|
6 |
+
import os
|
7 |
+
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
8 |
+
from tensorflow.keras.models import load_model
|
9 |
+
import segmentation_models as sm
|
10 |
+
from tensorflow.keras import backend as K
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
app = Flask(__name__)
|
14 |
+
app.config['UPLOAD_FOLDER'] = 'static/uploads'
|
15 |
+
|
16 |
+
# Load your model
|
17 |
+
weights = [0.1666] * 6 # Adjust if necessary
|
18 |
+
dice_loss = sm.losses.DiceLoss(class_weights=weights)
|
19 |
+
focal_loss = sm.losses.CategoricalFocalLoss()
|
20 |
+
total_loss = dice_loss + (1 * focal_loss)
|
21 |
+
|
22 |
+
def jacard_coef(y_true, y_pred):
|
23 |
+
y_true_f = K.flatten(y_true)
|
24 |
+
y_pred_f = K.flatten(y_pred)
|
25 |
+
intersection = K.sum(y_true_f * y_pred_f)
|
26 |
+
return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
|
27 |
+
|
28 |
+
model_path = r"C:\Users\sriva\Videos\dubai segmentation\satellite_standard_unet_100epochs.hdf5"
|
29 |
+
custom_objects = {
|
30 |
+
"dice_loss_plus_1focal_loss": total_loss,
|
31 |
+
"jacard_coef": jacard_coef
|
32 |
+
}
|
33 |
+
model = load_model(model_path, custom_objects=custom_objects)
|
34 |
+
|
35 |
+
@app.route('/')
|
36 |
+
def index():
|
37 |
+
return render_template('index.html')
|
38 |
+
|
39 |
+
@app.route('/predict', methods=['POST'])
|
40 |
+
def predict():
|
41 |
+
if 'file' not in request.files:
|
42 |
+
return "No file uploaded.", 400
|
43 |
+
file = request.files['file']
|
44 |
+
if file.filename == '':
|
45 |
+
return "No selected file.", 400
|
46 |
+
|
47 |
+
# Save the uploaded image
|
48 |
+
file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
|
49 |
+
file.save(file_path)
|
50 |
+
|
51 |
+
# Load and preprocess the image
|
52 |
+
test_img = Image.open(file_path)
|
53 |
+
desired_width = 256
|
54 |
+
desired_height = 256
|
55 |
+
test_img = test_img.resize((desired_width, desired_height))
|
56 |
+
test_img = np.array(test_img)
|
57 |
+
test_img_input = np.expand_dims(test_img, 0)
|
58 |
+
|
59 |
+
# Make the prediction
|
60 |
+
prediction = model.predict(test_img_input)
|
61 |
+
predicted_img = np.argmax(prediction, axis=3)[0, :, :]
|
62 |
+
|
63 |
+
# Save the predicted image
|
64 |
+
plt.imsave(os.path.join(app.config['UPLOAD_FOLDER'], 'predicted_' + file.filename), predicted_img, cmap='jet')
|
65 |
+
|
66 |
+
return render_template('index.html', uploaded_image=file.filename, predicted_image='predicted_' + file.filename)
|
67 |
+
|
|
|
|