Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import requests
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
|
9 |
+
# Model Architecture
|
10 |
+
class IcebergClassifier(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
self.conv = nn.Sequential(
|
14 |
+
nn.Conv2d(2, 16, 3, padding=1),
|
15 |
+
nn.ReLU(),
|
16 |
+
nn.MaxPool2d(2),
|
17 |
+
nn.Conv2d(16, 32, 3, padding=1),
|
18 |
+
nn.ReLU(),
|
19 |
+
nn.MaxPool2d(2),
|
20 |
+
nn.Conv2d(32, 64, 3, padding=1),
|
21 |
+
nn.ReLU(),
|
22 |
+
nn.MaxPool2d(2)
|
23 |
+
)
|
24 |
+
self.fc = nn.Sequential(
|
25 |
+
nn.Linear(64 * 9 * 9, 64),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Dropout(0.5),
|
28 |
+
nn.Linear(64, 1),
|
29 |
+
nn.Sigmoid()
|
30 |
+
)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x = self.conv(x)
|
34 |
+
x = x.view(x.size(0), -1)
|
35 |
+
return self.fc(x)
|
36 |
+
|
37 |
+
# Model loading with HuggingFace Hub
|
38 |
+
@st.cache_resource
|
39 |
+
def load_model():
|
40 |
+
try:
|
41 |
+
# Download the model from HuggingFace Hub
|
42 |
+
model_path = hf_hub_download(
|
43 |
+
repo_id="alperugurcan/iceberg",
|
44 |
+
filename="best_iceberg_model.pth"
|
45 |
+
)
|
46 |
+
|
47 |
+
model = IcebergClassifier()
|
48 |
+
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
49 |
+
model.eval()
|
50 |
+
return model
|
51 |
+
except Exception as e:
|
52 |
+
st.error(f"Error loading model: {str(e)}")
|
53 |
+
return None
|
54 |
+
|
55 |
+
def predict(band1, band2):
|
56 |
+
model = load_model()
|
57 |
+
if model is None:
|
58 |
+
return None
|
59 |
+
|
60 |
+
try:
|
61 |
+
# Transform images to model input format
|
62 |
+
band1 = np.array(band1).reshape(75, 75)
|
63 |
+
band2 = np.array(band2).reshape(75, 75)
|
64 |
+
image = np.stack([band1, band2])
|
65 |
+
image = torch.FloatTensor(image).unsqueeze(0)
|
66 |
+
|
67 |
+
# Prediction
|
68 |
+
with torch.no_grad():
|
69 |
+
pred = model(image)
|
70 |
+
|
71 |
+
return pred.item()
|
72 |
+
except Exception as e:
|
73 |
+
st.error(f"Error during prediction: {str(e)}")
|
74 |
+
return None
|
75 |
+
|
76 |
+
# Streamlit UI
|
77 |
+
def main():
|
78 |
+
st.title('π§ Iceberg vs Ship Classifier')
|
79 |
+
st.write("""
|
80 |
+
This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
|
81 |
+
The model uses two radar bands (HH and HV polarization) to make its prediction.
|
82 |
+
""")
|
83 |
+
|
84 |
+
# Sidebar with information
|
85 |
+
st.sidebar.header("About")
|
86 |
+
st.sidebar.info("""
|
87 |
+
This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
|
88 |
+
It uses radar data from two different polarizations to distinguish between ships and icebergs.
|
89 |
+
""")
|
90 |
+
|
91 |
+
st.sidebar.header("Input Format")
|
92 |
+
st.sidebar.info("""
|
93 |
+
Each band should be a 75x75 array of radar backscatter values in dB.
|
94 |
+
Values are typically between -50 and 50.
|
95 |
+
""")
|
96 |
+
|
97 |
+
# Main content
|
98 |
+
st.subheader('Radar Image Data')
|
99 |
+
col1, col2 = st.columns(2)
|
100 |
+
|
101 |
+
with col1:
|
102 |
+
band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
|
103 |
+
help='Enter 75x75 array in JSON format')
|
104 |
+
|
105 |
+
with col2:
|
106 |
+
band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
|
107 |
+
help='Enter 75x75 array in JSON format')
|
108 |
+
|
109 |
+
if st.button('π Predict'):
|
110 |
+
if not band1_text or not band2_text:
|
111 |
+
st.warning('Please enter data for both bands')
|
112 |
+
return
|
113 |
+
|
114 |
+
try:
|
115 |
+
# Parse JSON data
|
116 |
+
band1_data = json.loads(band1_text)
|
117 |
+
band2_data = json.loads(band2_text)
|
118 |
+
|
119 |
+
# Validate array dimensions
|
120 |
+
if (len(band1_data) != 75 or len(band1_data[0]) != 75 or
|
121 |
+
len(band2_data) != 75 or len(band2_data[0]) != 75):
|
122 |
+
st.error('Arrays must be 75x75 dimensions')
|
123 |
+
return
|
124 |
+
|
125 |
+
# Make prediction
|
126 |
+
with st.spinner('Making prediction...'):
|
127 |
+
probability = predict(band1_data, band2_data)
|
128 |
+
|
129 |
+
if probability is not None:
|
130 |
+
# Show results
|
131 |
+
st.subheader('Prediction Result')
|
132 |
+
|
133 |
+
# Create columns for the result display
|
134 |
+
result_col1, result_col2 = st.columns(2)
|
135 |
+
|
136 |
+
with result_col1:
|
137 |
+
st.metric("Iceberg Probability", f"{probability:.2%}")
|
138 |
+
|
139 |
+
with result_col2:
|
140 |
+
if probability > 0.5:
|
141 |
+
st.success('π§ ICEBERG')
|
142 |
+
else:
|
143 |
+
st.success('π’ SHIP')
|
144 |
+
|
145 |
+
# Progress bar
|
146 |
+
st.progress(probability)
|
147 |
+
|
148 |
+
# Confidence message
|
149 |
+
confidence = abs(probability - 0.5) * 2
|
150 |
+
if confidence > 0.8:
|
151 |
+
st.write("High confidence prediction")
|
152 |
+
elif confidence > 0.4:
|
153 |
+
st.write("Medium confidence prediction")
|
154 |
+
else:
|
155 |
+
st.write("Low confidence prediction")
|
156 |
+
|
157 |
+
except json.JSONDecodeError:
|
158 |
+
st.error('Please enter valid JSON format data')
|
159 |
+
except Exception as e:
|
160 |
+
st.error(f'An error occurred: {str(e)}')
|
161 |
+
|
162 |
+
# Example usage
|
163 |
+
with st.expander("See example input"):
|
164 |
+
st.code('''
|
165 |
+
# Example data format for each band:
|
166 |
+
[
|
167 |
+
[-32.5, -31.2, -30.8, ...], # 75 values
|
168 |
+
[-31.8, -30.9, -31.1, ...], # 75 values
|
169 |
+
... # 75 rows total
|
170 |
+
]
|
171 |
+
''')
|
172 |
+
|
173 |
+
# Footer
|
174 |
+
st.markdown('---')
|
175 |
+
st.markdown("""
|
176 |
+
<div style='text-align: center'>
|
177 |
+
<p>Made with β€οΈ using Streamlit |
|
178 |
+
<a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
|
179 |
+
</div>
|
180 |
+
""", unsafe_allow_html=True)
|
181 |
+
|
182 |
+
if __name__ == "__main__":
|
183 |
+
main()
|