alperugurcan commited on
Commit
778467b
Β·
verified Β·
1 Parent(s): e568359

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -173
app.py CHANGED
@@ -1,183 +1,30 @@
1
  import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
  import numpy as np
5
- import json
6
- import requests
7
  from huggingface_hub import hf_hub_download
8
 
9
- # Model Architecture
10
  class IcebergClassifier(nn.Module):
11
  def __init__(self):
12
  super().__init__()
13
- self.conv = nn.Sequential(
14
- nn.Conv2d(2, 16, 3, padding=1),
15
- nn.ReLU(),
16
- nn.MaxPool2d(2),
17
- nn.Conv2d(16, 32, 3, padding=1),
18
- nn.ReLU(),
19
- nn.MaxPool2d(2),
20
- nn.Conv2d(32, 64, 3, padding=1),
21
- nn.ReLU(),
22
- nn.MaxPool2d(2)
23
- )
24
- self.fc = nn.Sequential(
25
- nn.Linear(64 * 9 * 9, 64),
26
- nn.ReLU(),
27
- nn.Dropout(0.5),
28
- nn.Linear(64, 1),
29
- nn.Sigmoid()
30
- )
31
 
32
- def forward(self, x):
33
- x = self.conv(x)
34
- x = x.view(x.size(0), -1)
35
- return self.fc(x)
36
-
37
- # Model loading with HuggingFace Hub
38
  @st.cache_resource
39
- def load_model():
40
- try:
41
- # Download the model from HuggingFace Hub
42
- model_path = hf_hub_download(
43
- repo_id="alperugurcan/iceberg",
44
- filename="best_iceberg_model.pth"
45
- )
46
-
47
- model = IcebergClassifier()
48
- model.load_state_dict(torch.load(model_path, map_location='cpu'))
49
- model.eval()
50
- return model
51
- except Exception as e:
52
- st.error(f"Error loading model: {str(e)}")
53
- return None
54
-
55
- def predict(band1, band2):
56
- model = load_model()
57
- if model is None:
58
- return None
59
-
60
  try:
61
- # Transform images to model input format
62
- band1 = np.array(band1).reshape(75, 75)
63
- band2 = np.array(band2).reshape(75, 75)
64
- image = np.stack([band1, band2])
65
- image = torch.FloatTensor(image).unsqueeze(0)
66
-
67
- # Prediction
68
- with torch.no_grad():
69
- pred = model(image)
70
-
71
- return pred.item()
72
- except Exception as e:
73
- st.error(f"Error during prediction: {str(e)}")
74
- return None
75
-
76
- # Streamlit UI
77
- def main():
78
- st.title('🧊 Iceberg vs Ship Classifier')
79
- st.write("""
80
- This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
81
- The model uses two radar bands (HH and HV polarization) to make its prediction.
82
- """)
83
-
84
- # Sidebar with information
85
- st.sidebar.header("About")
86
- st.sidebar.info("""
87
- This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
88
- It uses radar data from two different polarizations to distinguish between ships and icebergs.
89
- """)
90
-
91
- st.sidebar.header("Input Format")
92
- st.sidebar.info("""
93
- Each band should be a 75x75 array of radar backscatter values in dB.
94
- Values are typically between -50 and 50.
95
- """)
96
-
97
- # Main content
98
- st.subheader('Radar Image Data')
99
- col1, col2 = st.columns(2)
100
-
101
- with col1:
102
- band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
103
- help='Enter 75x75 array in JSON format')
104
-
105
- with col2:
106
- band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
107
- help='Enter 75x75 array in JSON format')
108
-
109
- if st.button('πŸ” Predict'):
110
- if not band1_text or not band2_text:
111
- st.warning('Please enter data for both bands')
112
- return
113
-
114
- try:
115
- # Parse JSON data
116
- band1_data = json.loads(band1_text)
117
- band2_data = json.loads(band2_text)
118
-
119
- # Validate array dimensions
120
- if (len(band1_data) != 75 or len(band1_data[0]) != 75 or
121
- len(band2_data) != 75 or len(band2_data[0]) != 75):
122
- st.error('Arrays must be 75x75 dimensions')
123
- return
124
-
125
- # Make prediction
126
- with st.spinner('Making prediction...'):
127
- probability = predict(band1_data, band2_data)
128
-
129
- if probability is not None:
130
- # Show results
131
- st.subheader('Prediction Result')
132
-
133
- # Create columns for the result display
134
- result_col1, result_col2 = st.columns(2)
135
-
136
- with result_col1:
137
- st.metric("Iceberg Probability", f"{probability:.2%}")
138
-
139
- with result_col2:
140
- if probability > 0.5:
141
- st.success('🧊 ICEBERG')
142
- else:
143
- st.success('🚒 SHIP')
144
-
145
- # Progress bar
146
- st.progress(probability)
147
-
148
- # Confidence message
149
- confidence = abs(probability - 0.5) * 2
150
- if confidence > 0.8:
151
- st.write("High confidence prediction")
152
- elif confidence > 0.4:
153
- st.write("Medium confidence prediction")
154
- else:
155
- st.write("Low confidence prediction")
156
-
157
- except json.JSONDecodeError:
158
- st.error('Please enter valid JSON format data')
159
- except Exception as e:
160
- st.error(f'An error occurred: {str(e)}')
161
-
162
- # Example usage
163
- with st.expander("See example input"):
164
- st.code('''
165
- # Example data format for each band:
166
- [
167
- [-32.5, -31.2, -30.8, ...], # 75 values
168
- [-31.8, -30.9, -31.1, ...], # 75 values
169
- ... # 75 rows total
170
- ]
171
- ''')
172
-
173
- # Footer
174
- st.markdown('---')
175
- st.markdown("""
176
- <div style='text-align: center'>
177
- <p>Made with ❀️ using Streamlit |
178
- <a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
179
- </div>
180
- """, unsafe_allow_html=True)
181
-
182
- if __name__ == "__main__":
183
- main()
 
1
  import streamlit as st
2
+ import torch, torch.nn as nn
 
3
  import numpy as np
 
 
4
  from huggingface_hub import hf_hub_download
5
 
 
6
  class IcebergClassifier(nn.Module):
7
  def __init__(self):
8
  super().__init__()
9
+ self.conv = nn.Sequential(nn.Conv2d(2,16,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
10
+ nn.Conv2d(16,32,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
11
+ nn.Conv2d(32,64,3,padding=1),nn.ReLU(),nn.MaxPool2d(2))
12
+ self.fc = nn.Sequential(nn.Linear(64*9*9,64),nn.ReLU(),nn.Dropout(0.5),
13
+ nn.Linear(64,1),nn.Sigmoid())
14
+ def forward(self, x): return self.fc(self.conv(x).view(x.size(0),-1))
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
16
  @st.cache_resource
17
+ def predict(b1,b2):
18
+ model = IcebergClassifier().eval()
19
+ model.load_state_dict(torch.load(hf_hub_download("alperugurcan/iceberg","best_iceberg_model.pth"),map_location='cpu'))
20
+ return model(torch.FloatTensor(np.stack([np.array(b1).reshape(75,75),np.array(b2).reshape(75,75)])).unsqueeze(0)).item()
21
+
22
+ st.title('🧊 Iceberg vs Ship Classifier')
23
+ examples = {"Ship":[[-32.5]*75]*75,"Iceberg":[[-28.3]*75]*75,"Strong":[[-35.8]*75]*75}
24
+ for k,v in examples.items(): st.code(f"Example ({k}): {v}")
25
+ b1,b2 = st.text_area('Band 1'),st.text_area('Band 2')
26
+ if st.button('Predict'):
 
 
 
 
 
 
 
 
 
 
 
27
  try:
28
+ p = predict(eval(b1),eval(b2))
29
+ st.write(f"{'🧊 ICEBERG' if p>0.5 else '🚒 SHIP'} ({p:.1%})")
30
+ except: st.error('Invalid input')