Tanzeer commited on
Commit
e4b253e
·
1 Parent(s): 3489552

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -38
app.py CHANGED
@@ -1,22 +1,20 @@
1
- import streamlit as st
2
  import cv2
3
  import numpy as np
4
  import torch
5
- from torchvision import transforms, models
6
  from PIL import Image
7
  from TranSalNet_Res import TranSalNet
8
- import torch.nn as nn
9
  from utils.data_process import preprocess_img, postprocess_img
10
 
 
 
 
11
  device = torch.device('cpu')
12
  model = TranSalNet()
13
  model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
14
  model.to(device)
15
  model.eval()
16
 
17
- import cv2
18
- import numpy as np
19
-
20
  def count_and_label_red_patches(heatmap, threshold=200):
21
  red_mask = heatmap[:, :, 2] > threshold
22
  contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -58,46 +56,49 @@ def count_and_label_red_patches(heatmap, threshold=200):
58
 
59
  return original_image, len(contours)
60
 
 
 
 
 
 
 
 
 
61
 
62
- st.title('Saliency Detection App')
63
- st.write('Upload an image for saliency detection:')
64
- uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
65
-
66
- if uploaded_image:
67
- image = Image.open(uploaded_image)
68
- st.image(image, caption='Uploaded Image', use_column_width=True)
69
 
70
- if st.button('Detect Saliency'):
71
- img = image.resize((384, 288))
72
- img = np.array(img)
73
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert to BGR color space
74
- img = np.array(img) / 255.
75
- img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
76
- img = torch.from_numpy(img)
77
- img = img.type(torch.FloatTensor).to(device)
78
 
79
- pred_saliency = model(img).squeeze().detach().numpy()
80
 
81
- heatmap = (pred_saliency * 255).astype(np.uint8)
82
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # Use a blue colormap (JET)
 
 
 
83
 
84
- heatmap = cv2.resize(heatmap, (image.width, image.height))
 
85
 
86
- enhanced_image = np.array(image)
87
- b, g, r = cv2.split(enhanced_image)
88
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
89
- b_enhanced = clahe.apply(b)
90
- enhanced_image = cv2.merge((b_enhanced, g, r))
91
 
92
- alpha = 0.7
93
- blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
94
 
95
- original_image, num_red_patches = count_and_label_red_patches(heatmap)
96
 
97
- st.image(original_image, caption=f'Image with {num_red_patches} Red Patches', use_column_width=True, channels='RGB')
 
 
 
 
 
 
98
 
99
- st.image(blended_img, caption='Blended Image', use_column_width=True, channels='BGR')
 
 
100
 
101
- # Create a dir with the name example to save
102
- cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
103
- st.success('Saliency detection complete. Result saved as "example/result15.png".')
 
 
1
  import cv2
2
  import numpy as np
3
  import torch
4
+ from fastapi import FastAPI, UploadFile, File
5
  from PIL import Image
6
  from TranSalNet_Res import TranSalNet
 
7
  from utils.data_process import preprocess_img, postprocess_img
8
 
9
+
10
+ app = FastAPI()
11
+
12
  device = torch.device('cpu')
13
  model = TranSalNet()
14
  model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
15
  model.to(device)
16
  model.eval()
17
 
 
 
 
18
  def count_and_label_red_patches(heatmap, threshold=200):
19
  red_mask = heatmap[:, :, 2] > threshold
20
  contours, _ = cv2.findContours(red_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
56
 
57
  return original_image, len(contours)
58
 
59
+ def process_image(image: Image.Image) -> np.ndarray:
60
+ img = image.resize((384, 288))
61
+ img = np.array(img)
62
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Convert to BGR color space
63
+ img = np.array(img) / 255.
64
+ img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
65
+ img = torch.from_numpy(img)
66
+ img = img.type(torch.FloatTensor).to(device)
67
 
68
+ pred_saliency = model(img).squeeze().detach().numpy()
 
 
 
 
 
 
69
 
70
+ heatmap = (pred_saliency * 255).astype(np.uint8)
71
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # Use a blue colormap (JET)
 
 
 
 
 
 
72
 
73
+ heatmap = cv2.resize(heatmap, (image.width, image.height))
74
 
75
+ enhanced_image = np.array(image)
76
+ b, g, r = cv2.split(enhanced_image)
77
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
78
+ b_enhanced = clahe.apply(b)
79
+ enhanced_image = cv2.merge((b_enhanced, g, r))
80
 
81
+ alpha = 0.7
82
+ blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
83
 
84
+ original_image, num_red_patches = count_and_label_red_patches(heatmap)
 
 
 
 
85
 
86
+ # Save processed image (optional)
87
+ cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
88
 
89
+ return blended_img
90
 
91
+ @app.post("/process_image")
92
+ async def process_uploaded_image(file: UploadFile = File(...)):
93
+ try:
94
+ contents = await file.read()
95
+ image = Image.open(io.BytesIO(contents))
96
+ except Exception as e:
97
+ raise HTTPException(status_code=400, detail=f"Error opening image: {str(e)}")
98
 
99
+ try:
100
+ processed_image = process_image(image)
101
+ return StreamingResponse(io.BytesIO(cv2.imencode('.png', processed_image)[1].tobytes()), media_type="image/png")
102
 
103
+ except Exception as e:
104
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")