Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,183 +1,30 @@
|
|
1 |
import streamlit as st
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
import numpy as np
|
5 |
-
import json
|
6 |
-
import requests
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
|
9 |
-
# Model Architecture
|
10 |
class IcebergClassifier(nn.Module):
|
11 |
def __init__(self):
|
12 |
super().__init__()
|
13 |
-
self.conv = nn.Sequential(
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
nn.MaxPool2d(2),
|
20 |
-
nn.Conv2d(32, 64, 3, padding=1),
|
21 |
-
nn.ReLU(),
|
22 |
-
nn.MaxPool2d(2)
|
23 |
-
)
|
24 |
-
self.fc = nn.Sequential(
|
25 |
-
nn.Linear(64 * 9 * 9, 64),
|
26 |
-
nn.ReLU(),
|
27 |
-
nn.Dropout(0.5),
|
28 |
-
nn.Linear(64, 1),
|
29 |
-
nn.Sigmoid()
|
30 |
-
)
|
31 |
|
32 |
-
def forward(self, x):
|
33 |
-
x = self.conv(x)
|
34 |
-
x = x.view(x.size(0), -1)
|
35 |
-
return self.fc(x)
|
36 |
-
|
37 |
-
# Model loading with HuggingFace Hub
|
38 |
@st.cache_resource
|
39 |
-
def
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
model.eval()
|
50 |
-
return model
|
51 |
-
except Exception as e:
|
52 |
-
st.error(f"Error loading model: {str(e)}")
|
53 |
-
return None
|
54 |
-
|
55 |
-
def predict(band1, band2):
|
56 |
-
model = load_model()
|
57 |
-
if model is None:
|
58 |
-
return None
|
59 |
-
|
60 |
try:
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
image = np.stack([band1, band2])
|
65 |
-
image = torch.FloatTensor(image).unsqueeze(0)
|
66 |
-
|
67 |
-
# Prediction
|
68 |
-
with torch.no_grad():
|
69 |
-
pred = model(image)
|
70 |
-
|
71 |
-
return pred.item()
|
72 |
-
except Exception as e:
|
73 |
-
st.error(f"Error during prediction: {str(e)}")
|
74 |
-
return None
|
75 |
-
|
76 |
-
# Streamlit UI
|
77 |
-
def main():
|
78 |
-
st.title('π§ Iceberg vs Ship Classifier')
|
79 |
-
st.write("""
|
80 |
-
This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
|
81 |
-
The model uses two radar bands (HH and HV polarization) to make its prediction.
|
82 |
-
""")
|
83 |
-
|
84 |
-
# Sidebar with information
|
85 |
-
st.sidebar.header("About")
|
86 |
-
st.sidebar.info("""
|
87 |
-
This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
|
88 |
-
It uses radar data from two different polarizations to distinguish between ships and icebergs.
|
89 |
-
""")
|
90 |
-
|
91 |
-
st.sidebar.header("Input Format")
|
92 |
-
st.sidebar.info("""
|
93 |
-
Each band should be a 75x75 array of radar backscatter values in dB.
|
94 |
-
Values are typically between -50 and 50.
|
95 |
-
""")
|
96 |
-
|
97 |
-
# Main content
|
98 |
-
st.subheader('Radar Image Data')
|
99 |
-
col1, col2 = st.columns(2)
|
100 |
-
|
101 |
-
with col1:
|
102 |
-
band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
|
103 |
-
help='Enter 75x75 array in JSON format')
|
104 |
-
|
105 |
-
with col2:
|
106 |
-
band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
|
107 |
-
help='Enter 75x75 array in JSON format')
|
108 |
-
|
109 |
-
if st.button('π Predict'):
|
110 |
-
if not band1_text or not band2_text:
|
111 |
-
st.warning('Please enter data for both bands')
|
112 |
-
return
|
113 |
-
|
114 |
-
try:
|
115 |
-
# Parse JSON data
|
116 |
-
band1_data = json.loads(band1_text)
|
117 |
-
band2_data = json.loads(band2_text)
|
118 |
-
|
119 |
-
# Validate array dimensions
|
120 |
-
if (len(band1_data) != 75 or len(band1_data[0]) != 75 or
|
121 |
-
len(band2_data) != 75 or len(band2_data[0]) != 75):
|
122 |
-
st.error('Arrays must be 75x75 dimensions')
|
123 |
-
return
|
124 |
-
|
125 |
-
# Make prediction
|
126 |
-
with st.spinner('Making prediction...'):
|
127 |
-
probability = predict(band1_data, band2_data)
|
128 |
-
|
129 |
-
if probability is not None:
|
130 |
-
# Show results
|
131 |
-
st.subheader('Prediction Result')
|
132 |
-
|
133 |
-
# Create columns for the result display
|
134 |
-
result_col1, result_col2 = st.columns(2)
|
135 |
-
|
136 |
-
with result_col1:
|
137 |
-
st.metric("Iceberg Probability", f"{probability:.2%}")
|
138 |
-
|
139 |
-
with result_col2:
|
140 |
-
if probability > 0.5:
|
141 |
-
st.success('π§ ICEBERG')
|
142 |
-
else:
|
143 |
-
st.success('π’ SHIP')
|
144 |
-
|
145 |
-
# Progress bar
|
146 |
-
st.progress(probability)
|
147 |
-
|
148 |
-
# Confidence message
|
149 |
-
confidence = abs(probability - 0.5) * 2
|
150 |
-
if confidence > 0.8:
|
151 |
-
st.write("High confidence prediction")
|
152 |
-
elif confidence > 0.4:
|
153 |
-
st.write("Medium confidence prediction")
|
154 |
-
else:
|
155 |
-
st.write("Low confidence prediction")
|
156 |
-
|
157 |
-
except json.JSONDecodeError:
|
158 |
-
st.error('Please enter valid JSON format data')
|
159 |
-
except Exception as e:
|
160 |
-
st.error(f'An error occurred: {str(e)}')
|
161 |
-
|
162 |
-
# Example usage
|
163 |
-
with st.expander("See example input"):
|
164 |
-
st.code('''
|
165 |
-
# Example data format for each band:
|
166 |
-
[
|
167 |
-
[-32.5, -31.2, -30.8, ...], # 75 values
|
168 |
-
[-31.8, -30.9, -31.1, ...], # 75 values
|
169 |
-
... # 75 rows total
|
170 |
-
]
|
171 |
-
''')
|
172 |
-
|
173 |
-
# Footer
|
174 |
-
st.markdown('---')
|
175 |
-
st.markdown("""
|
176 |
-
<div style='text-align: center'>
|
177 |
-
<p>Made with β€οΈ using Streamlit |
|
178 |
-
<a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
|
179 |
-
</div>
|
180 |
-
""", unsafe_allow_html=True)
|
181 |
-
|
182 |
-
if __name__ == "__main__":
|
183 |
-
main()
|
|
|
1 |
import streamlit as st
|
2 |
+
import torch, torch.nn as nn
|
|
|
3 |
import numpy as np
|
|
|
|
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
|
|
|
6 |
class IcebergClassifier(nn.Module):
|
7 |
def __init__(self):
|
8 |
super().__init__()
|
9 |
+
self.conv = nn.Sequential(nn.Conv2d(2,16,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
|
10 |
+
nn.Conv2d(16,32,3,padding=1),nn.ReLU(),nn.MaxPool2d(2),
|
11 |
+
nn.Conv2d(32,64,3,padding=1),nn.ReLU(),nn.MaxPool2d(2))
|
12 |
+
self.fc = nn.Sequential(nn.Linear(64*9*9,64),nn.ReLU(),nn.Dropout(0.5),
|
13 |
+
nn.Linear(64,1),nn.Sigmoid())
|
14 |
+
def forward(self, x): return self.fc(self.conv(x).view(x.size(0),-1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
@st.cache_resource
|
17 |
+
def predict(b1,b2):
|
18 |
+
model = IcebergClassifier().eval()
|
19 |
+
model.load_state_dict(torch.load(hf_hub_download("alperugurcan/iceberg","best_iceberg_model.pth"),map_location='cpu'))
|
20 |
+
return model(torch.FloatTensor(np.stack([np.array(b1).reshape(75,75),np.array(b2).reshape(75,75)])).unsqueeze(0)).item()
|
21 |
+
|
22 |
+
st.title('π§ Iceberg vs Ship Classifier')
|
23 |
+
examples = {"Ship":[[-32.5]*75]*75,"Iceberg":[[-28.3]*75]*75,"Strong":[[-35.8]*75]*75}
|
24 |
+
for k,v in examples.items(): st.code(f"Example ({k}): {v}")
|
25 |
+
b1,b2 = st.text_area('Band 1'),st.text_area('Band 2')
|
26 |
+
if st.button('Predict'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
try:
|
28 |
+
p = predict(eval(b1),eval(b2))
|
29 |
+
st.write(f"{'π§ ICEBERG' if p>0.5 else 'π’ SHIP'} ({p:.1%})")
|
30 |
+
except: st.error('Invalid input')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|