alperugurcan commited on
Commit
ab484fd
Β·
verified Β·
1 Parent(s): 8689b43

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import json
6
+ import requests
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # Model Architecture
10
+ class IcebergClassifier(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.conv = nn.Sequential(
14
+ nn.Conv2d(2, 16, 3, padding=1),
15
+ nn.ReLU(),
16
+ nn.MaxPool2d(2),
17
+ nn.Conv2d(16, 32, 3, padding=1),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(2),
20
+ nn.Conv2d(32, 64, 3, padding=1),
21
+ nn.ReLU(),
22
+ nn.MaxPool2d(2)
23
+ )
24
+ self.fc = nn.Sequential(
25
+ nn.Linear(64 * 9 * 9, 64),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.5),
28
+ nn.Linear(64, 1),
29
+ nn.Sigmoid()
30
+ )
31
+
32
+ def forward(self, x):
33
+ x = self.conv(x)
34
+ x = x.view(x.size(0), -1)
35
+ return self.fc(x)
36
+
37
+ # Model loading with HuggingFace Hub
38
+ @st.cache_resource
39
+ def load_model():
40
+ try:
41
+ # Download the model from HuggingFace Hub
42
+ model_path = hf_hub_download(
43
+ repo_id="alperugurcan/iceberg",
44
+ filename="best_iceberg_model.pth"
45
+ )
46
+
47
+ model = IcebergClassifier()
48
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
49
+ model.eval()
50
+ return model
51
+ except Exception as e:
52
+ st.error(f"Error loading model: {str(e)}")
53
+ return None
54
+
55
+ def predict(band1, band2):
56
+ model = load_model()
57
+ if model is None:
58
+ return None
59
+
60
+ try:
61
+ # Transform images to model input format
62
+ band1 = np.array(band1).reshape(75, 75)
63
+ band2 = np.array(band2).reshape(75, 75)
64
+ image = np.stack([band1, band2])
65
+ image = torch.FloatTensor(image).unsqueeze(0)
66
+
67
+ # Prediction
68
+ with torch.no_grad():
69
+ pred = model(image)
70
+
71
+ return pred.item()
72
+ except Exception as e:
73
+ st.error(f"Error during prediction: {str(e)}")
74
+ return None
75
+
76
+ # Streamlit UI
77
+ def main():
78
+ st.title('🧊 Iceberg vs Ship Classifier')
79
+ st.write("""
80
+ This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
81
+ The model uses two radar bands (HH and HV polarization) to make its prediction.
82
+ """)
83
+
84
+ # Sidebar with information
85
+ st.sidebar.header("About")
86
+ st.sidebar.info("""
87
+ This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
88
+ It uses radar data from two different polarizations to distinguish between ships and icebergs.
89
+ """)
90
+
91
+ st.sidebar.header("Input Format")
92
+ st.sidebar.info("""
93
+ Each band should be a 75x75 array of radar backscatter values in dB.
94
+ Values are typically between -50 and 50.
95
+ """)
96
+
97
+ # Main content
98
+ st.subheader('Radar Image Data')
99
+ col1, col2 = st.columns(2)
100
+
101
+ with col1:
102
+ band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
103
+ help='Enter 75x75 array in JSON format')
104
+
105
+ with col2:
106
+ band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
107
+ help='Enter 75x75 array in JSON format')
108
+
109
+ if st.button('πŸ” Predict'):
110
+ if not band1_text or not band2_text:
111
+ st.warning('Please enter data for both bands')
112
+ return
113
+
114
+ try:
115
+ # Parse JSON data
116
+ band1_data = json.loads(band1_text)
117
+ band2_data = json.loads(band2_text)
118
+
119
+ # Validate array dimensions
120
+ if (len(band1_data) != 75 or len(band1_data[0]) != 75 or
121
+ len(band2_data) != 75 or len(band2_data[0]) != 75):
122
+ st.error('Arrays must be 75x75 dimensions')
123
+ return
124
+
125
+ # Make prediction
126
+ with st.spinner('Making prediction...'):
127
+ probability = predict(band1_data, band2_data)
128
+
129
+ if probability is not None:
130
+ # Show results
131
+ st.subheader('Prediction Result')
132
+
133
+ # Create columns for the result display
134
+ result_col1, result_col2 = st.columns(2)
135
+
136
+ with result_col1:
137
+ st.metric("Iceberg Probability", f"{probability:.2%}")
138
+
139
+ with result_col2:
140
+ if probability > 0.5:
141
+ st.success('🧊 ICEBERG')
142
+ else:
143
+ st.success('🚒 SHIP')
144
+
145
+ # Progress bar
146
+ st.progress(probability)
147
+
148
+ # Confidence message
149
+ confidence = abs(probability - 0.5) * 2
150
+ if confidence > 0.8:
151
+ st.write("High confidence prediction")
152
+ elif confidence > 0.4:
153
+ st.write("Medium confidence prediction")
154
+ else:
155
+ st.write("Low confidence prediction")
156
+
157
+ except json.JSONDecodeError:
158
+ st.error('Please enter valid JSON format data')
159
+ except Exception as e:
160
+ st.error(f'An error occurred: {str(e)}')
161
+
162
+ # Example usage
163
+ with st.expander("See example input"):
164
+ st.code('''
165
+ # Example data format for each band:
166
+ [
167
+ [-32.5, -31.2, -30.8, ...], # 75 values
168
+ [-31.8, -30.9, -31.1, ...], # 75 values
169
+ ... # 75 rows total
170
+ ]
171
+ ''')
172
+
173
+ # Footer
174
+ st.markdown('---')
175
+ st.markdown("""
176
+ <div style='text-align: center'>
177
+ <p>Made with ❀️ using Streamlit |
178
+ <a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
179
+ </div>
180
+ """, unsafe_allow_html=True)
181
+
182
+ if __name__ == "__main__":
183
+ main()