srivatsavdamaraju commited on
Commit
eddf030
·
verified ·
1 Parent(s): 4105b6d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -69
main.py CHANGED
@@ -1,69 +1,67 @@
1
- from flask import Flask, request, render_template
2
- import numpy as np
3
- import os
4
- from PIL import Image
5
- import tensorflow as tf
6
- import os
7
- os.environ["SM_FRAMEWORK"] = "tf.keras"
8
- from tensorflow.keras.models import load_model
9
- import segmentation_models as sm
10
- from tensorflow.keras import backend as K
11
- import matplotlib.pyplot as plt
12
-
13
- app = Flask(__name__)
14
- app.config['UPLOAD_FOLDER'] = 'static/uploads'
15
-
16
- # Load your model
17
- weights = [0.1666] * 6 # Adjust if necessary
18
- dice_loss = sm.losses.DiceLoss(class_weights=weights)
19
- focal_loss = sm.losses.CategoricalFocalLoss()
20
- total_loss = dice_loss + (1 * focal_loss)
21
-
22
- def jacard_coef(y_true, y_pred):
23
- y_true_f = K.flatten(y_true)
24
- y_pred_f = K.flatten(y_pred)
25
- intersection = K.sum(y_true_f * y_pred_f)
26
- return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
27
-
28
- model_path = r"C:\Users\sriva\Videos\dubai segmentation\satellite_standard_unet_100epochs.hdf5"
29
- custom_objects = {
30
- "dice_loss_plus_1focal_loss": total_loss,
31
- "jacard_coef": jacard_coef
32
- }
33
- model = load_model(model_path, custom_objects=custom_objects)
34
-
35
- @app.route('/')
36
- def index():
37
- return render_template('index.html')
38
-
39
- @app.route('/predict', methods=['POST'])
40
- def predict():
41
- if 'file' not in request.files:
42
- return "No file uploaded.", 400
43
- file = request.files['file']
44
- if file.filename == '':
45
- return "No selected file.", 400
46
-
47
- # Save the uploaded image
48
- file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
49
- file.save(file_path)
50
-
51
- # Load and preprocess the image
52
- test_img = Image.open(file_path)
53
- desired_width = 256
54
- desired_height = 256
55
- test_img = test_img.resize((desired_width, desired_height))
56
- test_img = np.array(test_img)
57
- test_img_input = np.expand_dims(test_img, 0)
58
-
59
- # Make the prediction
60
- prediction = model.predict(test_img_input)
61
- predicted_img = np.argmax(prediction, axis=3)[0, :, :]
62
-
63
- # Save the predicted image
64
- plt.imsave(os.path.join(app.config['UPLOAD_FOLDER'], 'predicted_' + file.filename), predicted_img, cmap='jet')
65
-
66
- return render_template('index.html', uploaded_image=file.filename, predicted_image='predicted_' + file.filename)
67
-
68
- if __name__ == '__main__':
69
- app.run(debug=True)
 
1
+ from flask import Flask, request, render_template
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ import tensorflow as tf
6
+ import os
7
+ os.environ["SM_FRAMEWORK"] = "tf.keras"
8
+ from tensorflow.keras.models import load_model
9
+ import segmentation_models as sm
10
+ from tensorflow.keras import backend as K
11
+ import matplotlib.pyplot as plt
12
+
13
+ app = Flask(__name__)
14
+ app.config['UPLOAD_FOLDER'] = 'static/uploads'
15
+
16
+ # Load your model
17
+ weights = [0.1666] * 6 # Adjust if necessary
18
+ dice_loss = sm.losses.DiceLoss(class_weights=weights)
19
+ focal_loss = sm.losses.CategoricalFocalLoss()
20
+ total_loss = dice_loss + (1 * focal_loss)
21
+
22
+ def jacard_coef(y_true, y_pred):
23
+ y_true_f = K.flatten(y_true)
24
+ y_pred_f = K.flatten(y_pred)
25
+ intersection = K.sum(y_true_f * y_pred_f)
26
+ return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
27
+
28
+ model_path = r"C:\Users\sriva\Videos\dubai segmentation\satellite_standard_unet_100epochs.hdf5"
29
+ custom_objects = {
30
+ "dice_loss_plus_1focal_loss": total_loss,
31
+ "jacard_coef": jacard_coef
32
+ }
33
+ model = load_model(model_path, custom_objects=custom_objects)
34
+
35
+ @app.route('/')
36
+ def index():
37
+ return render_template('index.html')
38
+
39
+ @app.route('/predict', methods=['POST'])
40
+ def predict():
41
+ if 'file' not in request.files:
42
+ return "No file uploaded.", 400
43
+ file = request.files['file']
44
+ if file.filename == '':
45
+ return "No selected file.", 400
46
+
47
+ # Save the uploaded image
48
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
49
+ file.save(file_path)
50
+
51
+ # Load and preprocess the image
52
+ test_img = Image.open(file_path)
53
+ desired_width = 256
54
+ desired_height = 256
55
+ test_img = test_img.resize((desired_width, desired_height))
56
+ test_img = np.array(test_img)
57
+ test_img_input = np.expand_dims(test_img, 0)
58
+
59
+ # Make the prediction
60
+ prediction = model.predict(test_img_input)
61
+ predicted_img = np.argmax(prediction, axis=3)[0, :, :]
62
+
63
+ # Save the predicted image
64
+ plt.imsave(os.path.join(app.config['UPLOAD_FOLDER'], 'predicted_' + file.filename), predicted_img, cmap='jet')
65
+
66
+ return render_template('index.html', uploaded_image=file.filename, predicted_image='predicted_' + file.filename)
67
+