svsaurav95 commited on
Commit
744e6f4
·
verified ·
1 Parent(s): 41e3148

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +35 -48
src/streamlit_app.py CHANGED
@@ -3,23 +3,16 @@ import torch
3
  import torch.nn as nn
4
  import timm
5
  import numpy as np
6
- import cv2
7
  from PIL import Image
8
- import io
 
 
 
 
9
 
10
- # Hide Streamlit warnings and UI elements
11
- st.set_page_config(layout="wide")
12
- st.markdown("""
13
- <style>
14
- footer {visibility: hidden;}
15
- </style>
16
- """, unsafe_allow_html=True)
17
-
18
- MODEL_PATH = "mobilevit_teeth_segmentation.pth"
19
-
20
- # === Model Definition ===
21
  class MobileViTSegmentation(nn.Module):
22
- def __init__(self, encoder_name='mobilevit_s', pretrained=False):
23
  super().__init__()
24
  self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
25
  self.encoder_channels = self.backbone.feature_info.channels()
@@ -41,55 +34,49 @@ class MobileViTSegmentation(nn.Module):
41
  out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
42
  return out
43
 
44
- # === Load Model ===
45
  @st.cache_resource
46
  def load_model():
47
- model = MobileViTSegmentation()
48
- state_dict = torch.load( MODEL_PATH, map_location="cpu")
49
- model.load_state_dict(state_dict)
50
  model.eval()
51
  return model
52
 
53
  model = load_model()
54
 
55
- # === Preprocessing ===
56
- def preprocess_image(image: Image.Image):
57
- image = image.convert("RGB").resize((256, 256))
58
- arr = np.array(image).astype(np.float32) / 255.0
59
- arr = np.transpose(arr, (2, 0, 1)) # HWC → CHW
60
- tensor = torch.tensor(arr).unsqueeze(0) # Add batch dim
61
- return tensor
62
-
63
- # === Postprocessing: Overlay Mask ===
64
- def overlay_mask(image_pil, mask_tensor, threshold=0.7):
65
- image = np.array(image_pil.resize((256, 256)))
66
- mask = mask_tensor.squeeze().detach().numpy()
67
- mask_bin = (mask > threshold).astype(np.uint8) * 255
68
 
69
- mask_color = np.zeros_like(image)
70
- mask_color[..., 2] = mask_bin # Blue mask
71
 
72
- overlayed = cv2.addWeighted(image, 1.0, mask_color, 0.5, 0)
73
- return overlayed
74
-
75
- # === UI ===
76
- st.title("🦷 Tooth Segmentation with MobileViT")
77
- st.write("Upload an image to segment the **visible teeth area** using a lightweight MobileViT segmentation model.")
78
-
79
- uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
80
 
81
  if uploaded_file:
82
- image = Image.open(uploaded_file)
83
- tensor = preprocess_image(image)
 
 
 
84
 
85
- with st.spinner("Segmenting..."):
86
- with torch.no_grad():
87
- pred = model(tensor)[0]
88
 
89
- overlayed_img = overlay_mask(image, pred)
 
 
 
 
 
90
 
 
91
  col1, col2 = st.columns(2)
92
  with col1:
93
  st.image(image, caption="Original Image", use_container_width=True)
94
  with col2:
95
- st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
 
3
  import torch.nn as nn
4
  import timm
5
  import numpy as np
 
6
  from PIL import Image
7
+ import requests
8
+ from io import BytesIO
9
+ import torchvision.transforms as T
10
+ import matplotlib.pyplot as plt
11
+ from huggingface_hub import hf_hub_download
12
 
13
+ # ========== Model Definition ==========
 
 
 
 
 
 
 
 
 
 
14
  class MobileViTSegmentation(nn.Module):
15
+ def __init__(self, encoder_name='mobilevit_s', pretrained=True):
16
  super().__init__()
17
  self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
18
  self.encoder_channels = self.backbone.feature_info.channels()
 
34
  out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
35
  return out
36
 
37
+ # ========== Load Model ==========
38
  @st.cache_resource
39
  def load_model():
40
+ checkpoint_path = hf_hub_download(repo_id="svsaurav95/ToothSegmentation", filename="mobilevit_teeth_segmentation.pth")
41
+ model = MobileViTSegmentation(pretrained=False)
42
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
43
  model.eval()
44
  return model
45
 
46
  model = load_model()
47
 
48
+ # ========== Image Transformation ==========
49
+ transform = T.Compose([
50
+ T.Resize((256, 256)),
51
+ T.ToTensor()
52
+ ])
 
 
 
 
 
 
 
 
53
 
54
+ # ========== Streamlit UI ==========
55
+ st.title("Tooth Segmentation using MobileViT")
56
 
57
+ uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
58
 
59
  if uploaded_file:
60
+ image = Image.open(uploaded_file).convert("RGB")
61
+ input_tensor = transform(image).unsqueeze(0)
62
+
63
+ with torch.no_grad():
64
+ pred_mask = model(input_tensor)[0, 0].numpy()
65
 
66
+ # Post-processing
67
+ pred_mask = (pred_mask > 0.7).astype(np.uint8) * 255
68
+ pred_mask = Image.fromarray(pred_mask).resize(image.size)
69
 
70
+ # Create overlay
71
+ overlay = Image.new("RGBA", image.size, (0, 0, 255, 100)) # Blue translucent
72
+ base = image.convert("RGBA")
73
+ pred_mask_rgba = Image.new("L", image.size, 0)
74
+ pred_mask_rgba.paste(255, mask=pred_mask)
75
+ final = Image.composite(overlay, base, pred_mask_rgba)
76
 
77
+ # Side-by-side display
78
  col1, col2 = st.columns(2)
79
  with col1:
80
  st.image(image, caption="Original Image", use_container_width=True)
81
  with col2:
82
+ st.image(final, caption="Tooth Segmentation Overlay", use_container_width=True)