svsaurav95 committed
Commit 174f191 · verified · Parent: aac18fb

Update app.py

Files changed (1):
  app.py (+55 -35)
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import streamlit as st
 import torch
 import torch.nn as nn
@@ -6,11 +7,13 @@ import timm
 import numpy as np
 import cv2
 from PIL import Image
-import io
 import warnings

 warnings.filterwarnings("ignore")

+# Optional: Turn off file watchers in HF Spaces to avoid torch-related warnings
+os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
+
 # Define the model class
 class MobileViTSegmentation(nn.Module):
     def __init__(self, encoder_name='mobilevit_s', pretrained=False):
@@ -34,53 +37,70 @@ class MobileViTSegmentation(nn.Module):
         out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
         return out

-# Load model function
+# Load model function with spinner and error handling
 @st.cache_resource
 def load_model():
-    model = MobileViTSegmentation()
-    model.load_state_dict(torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu"))
-    model.eval()
-    return model
+    try:
+        with st.spinner("Loading model..."):
+            model = MobileViTSegmentation()
+            model.load_state_dict(torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu"))
+            model.eval()
+            return model
+    except Exception as e:
+        st.error(f"❌ Failed to load model: {e}")
+        st.stop()

-# Inference
+# Inference function
 def predict_mask(image, model, threshold=0.7):
-    transform = transforms.Compose([
-        transforms.Resize((256, 256)),
-        transforms.ToTensor()
-    ])
-    img_tensor = transform(image).unsqueeze(0)
-    with torch.no_grad():
-        pred = model(img_tensor)
-    pred_mask = pred.squeeze().numpy()
-    pred_mask = (pred_mask > threshold).astype(np.uint8)
-    return pred_mask
+    try:
+        transform = transforms.Compose([
+            transforms.Resize((256, 256)),
+            transforms.ToTensor()
+        ])
+        img_tensor = transform(image).unsqueeze(0)
+        with torch.no_grad():
+            pred = model(img_tensor)
+        pred_mask = pred.squeeze().numpy()
+        pred_mask = (pred_mask > threshold).astype(np.uint8)
+        return pred_mask
+    except Exception as e:
+        st.error(f"❌ Prediction failed: {e}")
+        return None

 # Overlay mask on image
 def overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4):
-    image_np = np.array(image.convert("RGB"))
-    mask_resized = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
-    color_mask = np.zeros_like(image_np)
-    color_mask[:, :] = color
-    overlay = np.where(mask_resized[..., None] == 1, color_mask, 0)
-    blended = cv2.addWeighted(image_np, 1 - alpha, overlay, alpha, 0)
-    return blended
+    try:
+        image_np = np.array(image.convert("RGB"))
+        mask_resized = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
+        color_mask = np.zeros_like(image_np)
+        color_mask[:, :] = color
+        overlay = np.where(mask_resized[..., None] == 1, color_mask, 0)
+        blended = cv2.addWeighted(image_np, 1 - alpha, overlay, alpha, 0)
+        return blended
+    except Exception as e:
+        st.error(f"❌ Mask overlay failed: {e}")
+        return np.array(image)

 # Streamlit UI
+st.set_page_config(page_title="Tooth Segmentation", layout="wide")
 st.title("🦷 Tooth Segmentation from Mouth Images")
-st.markdown("Upload a face or mouth image and get the segmented **tooth region overlayed**.")
+st.markdown("Upload a **face or mouth image**, and this app will overlay the **predicted tooth segmentation mask**.")

 uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

 if uploaded_file:
-    image = Image.open(uploaded_file).convert("RGB")
-    model = load_model()
-    pred_mask = predict_mask(image, model)
-
-    overlayed_img = overlay_mask(image, pred_mask, color=(0, 0, 255), alpha=0.4)
+    try:
+        image = Image.open(uploaded_file).convert("RGB")
+        model = load_model()
+        pred_mask = predict_mask(image, model)

-    col1, col2 = st.columns(2)
-    with col1:
-        st.image(image, caption="Original Image", use_container_width=True)
-    with col2:
-        st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
+        if pred_mask is not None:
+            overlayed_img = overlay_mask(image, pred_mask, color=(0, 0, 255), alpha=0.4)

+            col1, col2 = st.columns(2)
+            with col1:
+                st.image(image, caption="Original Image", use_container_width=True)
+            with col2:
+                st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
+    except Exception as e:
+        st.error(f"❌ Error processing image: {e}")
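A quick way to exercise the refactored helpers locally is a minimal smoke test outside the Streamlit UI. This is only a sketch, not part of the commit: it assumes `mobilevit_teeth_segmentation.pth` sits next to `app.py`, that `sample_mouth.jpg` is a hypothetical test image you supply, and that importing `app` in Streamlit's bare mode is acceptable (its top-level `st.*` calls only emit warnings when run outside `streamlit run`).

```python
# Sketch: local smoke test for the updated load/predict/overlay path.
# Assumptions (not from the commit): the checkpoint is in the working
# directory and sample_mouth.jpg is a hypothetical test image.
import cv2
from PIL import Image

from app import load_model, predict_mask, overlay_mask

image = Image.open("sample_mouth.jpg").convert("RGB")

model = load_model()               # loads the .pth on CPU via st.cache_resource
mask = predict_mask(image, model)  # uint8 {0, 1} mask at 256x256, threshold 0.7
if mask is not None:
    blended = overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4)
    # overlay_mask returns an RGB ndarray (PIL channel order); convert for cv2.imwrite
    cv2.imwrite("overlay.png", cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))
    print("Saved overlay.png; foreground pixels:", int(mask.sum()))
```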
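One caveat on the new environment line: Streamlit's documented file-watcher option is `server.fileWatcherType` (valid values include `"none"`), whose environment override follows the standard `STREAMLIT_SERVER_FILE_WATCHER_TYPE` naming; whether the shorter `STREAMLIT_WATCHER_TYPE` name is honoured is worth confirming against the Streamlit docs. A hedged alternative sketch, assuming the standard config/env mapping:

```python
# Assumption: standard Streamlit config mapping (server.fileWatcherType).
# Config is usually resolved when `streamlit run` starts, so setting this in
# the Space's environment or .streamlit/config.toml is the more reliable route;
# setting it here only helps if it is read before the watcher is created.
import os

os.environ.setdefault("STREAMLIT_SERVER_FILE_WATCHER_TYPE", "none")
```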