svsaurav95 commited on
Commit
aac18fb
·
verified ·
1 Parent(s): e4d8b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -46
app.py CHANGED
@@ -1,27 +1,22 @@
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
 
4
  import timm
5
  import numpy as np
6
  import cv2
7
  from PIL import Image
8
  import io
 
9
 
10
- # Hide Streamlit warnings and UI elements
11
- st.set_page_config(layout="wide")
12
- st.markdown("""
13
- <style>
14
- footer {visibility: hidden;}
15
- </style>
16
- """, unsafe_allow_html=True)
17
 
18
- # === Model Definition ===
19
  class MobileViTSegmentation(nn.Module):
20
  def __init__(self, encoder_name='mobilevit_s', pretrained=False):
21
  super().__init__()
22
  self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
23
  self.encoder_channels = self.backbone.feature_info.channels()
24
-
25
  self.decoder = nn.Sequential(
26
  nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
27
  nn.Upsample(scale_factor=2, mode='bilinear'),
@@ -39,55 +34,53 @@ class MobileViTSegmentation(nn.Module):
39
  out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
40
  return out
41
 
42
- # === Load Model ===
43
  @st.cache_resource
44
  def load_model():
45
  model = MobileViTSegmentation()
46
- state_dict = torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu")
47
- model.load_state_dict(state_dict)
48
  model.eval()
49
  return model
50
 
51
- model = load_model()
52
-
53
- # === Preprocessing ===
54
- def preprocess_image(image: Image.Image):
55
- image = image.convert("RGB").resize((256, 256))
56
- arr = np.array(image).astype(np.float32) / 255.0
57
- arr = np.transpose(arr, (2, 0, 1)) # HWC → CHW
58
- tensor = torch.tensor(arr).unsqueeze(0) # Add batch dim
59
- return tensor
60
-
61
- # === Postprocessing: Overlay Mask ===
62
- def overlay_mask(image_pil, mask_tensor, threshold=0.7):
63
- image = np.array(image_pil.resize((256, 256)))
64
- mask = mask_tensor.squeeze().detach().numpy()
65
- mask_bin = (mask > threshold).astype(np.uint8) * 255
66
-
67
- mask_color = np.zeros_like(image)
68
- mask_color[..., 2] = mask_bin # Blue mask
69
-
70
- overlayed = cv2.addWeighted(image, 1.0, mask_color, 0.5, 0)
71
- return overlayed
72
-
73
- # === UI ===
74
- st.title("🦷 Tooth Segmentation with MobileViT")
75
- st.write("Upload an image to segment the **visible teeth area** using a lightweight MobileViT segmentation model.")
76
-
77
- uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
 
78
 
79
  if uploaded_file:
80
- image = Image.open(uploaded_file)
81
- tensor = preprocess_image(image)
 
82
 
83
- with st.spinner("Segmenting..."):
84
- with torch.no_grad():
85
- pred = model(tensor)[0]
86
-
87
- overlayed_img = overlay_mask(image, pred)
88
 
89
  col1, col2 = st.columns(2)
90
  with col1:
91
  st.image(image, caption="Original Image", use_container_width=True)
92
  with col2:
93
  st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
 
 
1
  import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
  import timm
6
  import numpy as np
7
  import cv2
8
  from PIL import Image
9
  import io
10
+ import warnings
11
 
12
+ warnings.filterwarnings("ignore")
 
 
 
 
 
 
13
 
14
+ # Define the model class
15
  class MobileViTSegmentation(nn.Module):
16
  def __init__(self, encoder_name='mobilevit_s', pretrained=False):
17
  super().__init__()
18
  self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
19
  self.encoder_channels = self.backbone.feature_info.channels()
 
20
  self.decoder = nn.Sequential(
21
  nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
22
  nn.Upsample(scale_factor=2, mode='bilinear'),
 
34
  out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
35
  return out
36
 
37
+ # Load model function
38
  @st.cache_resource
39
  def load_model():
40
  model = MobileViTSegmentation()
41
+ model.load_state_dict(torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu"))
 
42
  model.eval()
43
  return model
44
 
45
+ # Inference
46
+ def predict_mask(image, model, threshold=0.7):
47
+ transform = transforms.Compose([
48
+ transforms.Resize((256, 256)),
49
+ transforms.ToTensor()
50
+ ])
51
+ img_tensor = transform(image).unsqueeze(0)
52
+ with torch.no_grad():
53
+ pred = model(img_tensor)
54
+ pred_mask = pred.squeeze().numpy()
55
+ pred_mask = (pred_mask > threshold).astype(np.uint8)
56
+ return pred_mask
57
+
58
+ # Overlay mask on image
59
+ def overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4):
60
+ image_np = np.array(image.convert("RGB"))
61
+ mask_resized = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
62
+ color_mask = np.zeros_like(image_np)
63
+ color_mask[:, :] = color
64
+ overlay = np.where(mask_resized[..., None] == 1, color_mask, 0)
65
+ blended = cv2.addWeighted(image_np, 1 - alpha, overlay, alpha, 0)
66
+ return blended
67
+
68
+ # Streamlit UI
69
+ st.title("🦷 Tooth Segmentation from Mouth Images")
70
+ st.markdown("Upload a face or mouth image and get the segmented **tooth region overlayed**.")
71
+
72
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
73
 
74
  if uploaded_file:
75
+ image = Image.open(uploaded_file).convert("RGB")
76
+ model = load_model()
77
+ pred_mask = predict_mask(image, model)
78
 
79
+ overlayed_img = overlay_mask(image, pred_mask, color=(0, 0, 255), alpha=0.4)
 
 
 
 
80
 
81
  col1, col2 = st.columns(2)
82
  with col1:
83
  st.image(image, caption="Original Image", use_container_width=True)
84
  with col2:
85
  st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
86
+