svsaurav95 commited on
Commit
ea20fb6
·
verified ·
1 Parent(s): 78f86c4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +11 -13
src/streamlit_app.py CHANGED
@@ -4,12 +4,10 @@ import torch.nn as nn
4
  import timm
5
  import numpy as np
6
  from PIL import Image
7
- import requests
8
- from io import BytesIO
9
  import torchvision.transforms as T
10
- import matplotlib.pyplot as plt
11
  from huggingface_hub import hf_hub_download
12
 
 
13
  class MobileViTSegmentation(nn.Module):
14
  def __init__(self, encoder_name='mobilevit_s', pretrained=True):
15
  super().__init__()
@@ -36,13 +34,12 @@ class MobileViTSegmentation(nn.Module):
36
  # ========== Load Model ==========
37
  @st.cache_resource
38
  def load_model():
39
- from model import MobileViTSegmentation # or define it here if needed
40
- cache_dir = "/tmp/huggingface" # Writable path in HF Spaces
41
 
42
  checkpoint_path = hf_hub_download(
43
  repo_id="svsaurav95/ToothSegmentation",
44
  filename="mobilevit_teeth_segmentation.pth",
45
- cache_dir=cache_dir # <- this fixes the permission issue
46
  )
47
 
48
  model = MobileViTSegmentation(pretrained=False)
@@ -52,14 +49,15 @@ def load_model():
52
 
53
  model = load_model()
54
 
55
- # ========== Image Transformation ==========
56
  transform = T.Compose([
57
  T.Resize((256, 256)),
58
  T.ToTensor()
59
  ])
60
 
61
- # ========== Streamlit UI ==========
62
- st.title("Tooth Segmentation using MobileViT")
 
63
 
64
  uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])
65
 
@@ -70,12 +68,12 @@ if uploaded_file:
70
  with torch.no_grad():
71
  pred_mask = model(input_tensor)[0, 0].numpy()
72
 
73
- # Post-processing
74
  pred_mask = (pred_mask > 0.7).astype(np.uint8) * 255
75
  pred_mask = Image.fromarray(pred_mask).resize(image.size)
76
 
77
- # Create overlay
78
- overlay = Image.new("RGBA", image.size, (0, 0, 255, 100))
79
  base = image.convert("RGBA")
80
  pred_mask_rgba = Image.new("L", image.size, 0)
81
  pred_mask_rgba.paste(255, mask=pred_mask)
@@ -86,4 +84,4 @@ if uploaded_file:
86
  with col1:
87
  st.image(image, caption="Original Image", use_container_width=True)
88
  with col2:
89
- st.image(final, caption="Tooth Segmentation Overlay", use_container_width=True)
 
4
  import timm
5
  import numpy as np
6
  from PIL import Image
 
 
7
  import torchvision.transforms as T
 
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # ========== Model Definition ==========
11
  class MobileViTSegmentation(nn.Module):
12
  def __init__(self, encoder_name='mobilevit_s', pretrained=True):
13
  super().__init__()
 
34
  # ========== Load Model ==========
35
  @st.cache_resource
36
  def load_model():
37
+ cache_dir = "/tmp/huggingface" # Safe writable directory in HF Spaces
 
38
 
39
  checkpoint_path = hf_hub_download(
40
  repo_id="svsaurav95/ToothSegmentation",
41
  filename="mobilevit_teeth_segmentation.pth",
42
+ cache_dir=cache_dir
43
  )
44
 
45
  model = MobileViTSegmentation(pretrained=False)
 
49
 
50
  model = load_model()
51
 
52
+ # ========== Image Preprocessing ==========
53
  transform = T.Compose([
54
  T.Resize((256, 256)),
55
  T.ToTensor()
56
  ])
57
 
58
+ # ========== UI ==========
59
+ st.set_page_config(page_title="Tooth Segmentation", layout="wide")
60
+ st.title("🦷 Tooth Segmentation using MobileViT")
61
 
62
  uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])
63
 
 
68
  with torch.no_grad():
69
  pred_mask = model(input_tensor)[0, 0].numpy()
70
 
71
+ # Threshold and resize to original
72
  pred_mask = (pred_mask > 0.7).astype(np.uint8) * 255
73
  pred_mask = Image.fromarray(pred_mask).resize(image.size)
74
 
75
+ # Create translucent blue overlay
76
+ overlay = Image.new("RGBA", image.size, (0, 0, 255, 100))
77
  base = image.convert("RGBA")
78
  pred_mask_rgba = Image.new("L", image.size, 0)
79
  pred_mask_rgba.paste(255, mask=pred_mask)
 
84
  with col1:
85
  st.image(image, caption="Original Image", use_container_width=True)
86
  with col2:
87
+ st.image(final, caption="Tooth Area Segmentation", use_container_width=True)