alperugurcan commited on
Commit
1e860b2
Β·
verified Β·
1 Parent(s): fc00212

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -18
app.py CHANGED
@@ -1,30 +1,52 @@
1
  import streamlit as st
2
- import torch, torch.nn as nn
 
3
  import numpy as np
4
  from huggingface_hub import hf_hub_download
5
 
6
  class IcebergClassifier(nn.Module):
7
  def __init__(self):
8
  super().__init__()
9
- self.conv = nn.Sequential(nn.Conv2d(2,16,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
10
- nn.Conv2d(16,32,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
11
- nn.Conv2d(32,64,3,padding=1),nn.ReLU(),nn.MaxPool2d(2))
12
- self.fc = nn.Sequential(nn.Linear(64*9*9,64),nn.ReLU(),nn.Dropout(0.5),
13
- nn.Linear(64,1),nn.Sigmoid())
14
- def forward(self, x): return self.fc(self.conv(x).view(x.size(0),-1))
 
 
 
 
15
 
16
  @st.cache_resource
17
- def predict(b1,b2):
18
  model = IcebergClassifier().eval()
19
- model.load_state_dict(torch.load(hf_hub_download("alperugurcan/iceberg","best_iceberg_model.pth"),map_location='cpu'))
20
- return model(torch.FloatTensor(np.stack([np.array(b1).reshape(75,75),np.array(b2).reshape(75,75)])).unsqueeze(0)).item()
21
 
22
- st.title('🧊 Iceberg vs Ship Classifier')
23
- examples = {"Ship":[[-32.5]*75]*75,"Iceberg":[[-28.3]*75]*75,"Strong":[[-35.8]*75]*75}
24
- for k,v in examples.items(): st.code(f"Example ({k}): {v}")
25
- b1,b2 = st.text_area('Band 1'),st.text_area('Band 2')
26
- if st.button('Predict'):
 
 
27
  try:
28
- p = predict(eval(b1),eval(b2))
29
- st.write(f"{'🧊 ICEBERG' if p>0.5 else '🚒 SHIP'} ({p:.1%})")
30
- except: st.error('Invalid input')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
  import numpy as np
5
  from huggingface_hub import hf_hub_download
6
 
7
  class IcebergClassifier(nn.Module):
8
  def __init__(self):
9
  super().__init__()
10
+ self.conv = nn.Sequential(
11
+ nn.Conv2d(2,16,3), nn.ReLU(), nn.MaxPool2d(2),
12
+ nn.Conv2d(16,32,3), nn.ReLU(), nn.MaxPool2d(2)
13
+ )
14
+ self.fc = nn.Sequential(
15
+ nn.Linear(32*17*17,64), nn.ReLU(),
16
+ nn.Linear(64,1), nn.Sigmoid()
17
+ )
18
+ def forward(self, x):
19
+ return self.fc(self.conv(x).view(x.size(0),-1))
20
 
21
  @st.cache_resource
22
+ def load_model():
23
  model = IcebergClassifier().eval()
24
+ model.load_state_dict(torch.load(hf_hub_download("alperugurcan/iceberg","best_iceberg_model.pth"), map_location='cpu'))
25
+ return model
26
 
27
+ st.title('🧊 Simple Ship vs Iceberg Detector')
28
+
29
+ # Simple numeric inputs
30
+ band1 = st.number_input('Enter Band 1 value (-40 to -20)', -40.0, -20.0, -30.0)
31
+ band2 = st.number_input('Enter Band 2 value (-35 to -15)', -35.0, -15.0, -25.0)
32
+
33
+ if st.button('Detect'):
34
  try:
35
+ # Create simple 75x75 arrays with the input values
36
+ b1 = np.full((75,75), band1)
37
+ b2 = np.full((75,75), band2)
38
+
39
+ # Prepare input tensor
40
+ x = torch.FloatTensor(np.stack([b1,b2])).unsqueeze(0)
41
+
42
+ # Get prediction
43
+ model = load_model()
44
+ with torch.no_grad():
45
+ pred = model(x).item()
46
+
47
+ # Show result
48
+ result = "🧊 ICEBERG" if pred > 0.5 else "🚒 SHIP"
49
+ st.success(f"{result} ({pred:.1%})")
50
+
51
+ except Exception as e:
52
+ st.error(f'Error: {str(e)}')