File size: 5,806 Bytes
ab484fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import json
import requests
from huggingface_hub import hf_hub_download

# Model Architecture
class IcebergClassifier(nn.Module):
    """Binary CNN classifier for two-channel 75x75 radar patches.

    Three stages of Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) widen the
    channels 2 -> 16 -> 32 -> 64 while shrinking the spatial size
    75 -> 37 -> 18 -> 9. A dense head then maps the 64*9*9 flattened features
    to a single sigmoid output.

    The submodule layout (``conv.0``, ``conv.3``, ``conv.6``, ``fc.0``,
    ``fc.3``) must not change: it determines the ``state_dict`` keys that the
    saved checkpoint is loaded into.
    """

    def __init__(self):
        super().__init__()

        # Each stage contributes exactly three modules, so the conv layers land
        # at Sequential indices 0, 3 and 6.
        stages = []
        for in_ch, out_ch in ((2, 16), (16, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.conv = nn.Sequential(*stages)

        flat_features = 64 * 9 * 9  # 64 channels on a 9x9 map after three poolings
        head = [
            nn.Linear(flat_features, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Return one sigmoid score per sample in the batch ``x``."""
        features = self.conv(x)
        # Keep the batch dimension and flatten channels/height/width.
        flat = torch.flatten(features, 1)
        return self.fc(flat)

# Model loading with HuggingFace Hub
@st.cache_resource
def _fetch_model():
    """Download the checkpoint from the Hub and build an eval-mode classifier.

    Cached with ``st.cache_resource`` so the work runs once per server process.
    Errors are deliberately NOT handled here. An exception propagates instead
    of being cached, so a transient failure (e.g. a network error) is retried
    on the next call rather than pinning ``None`` in the cache for good.
    """
    model_path = hf_hub_download(
        repo_id="alperugurcan/iceberg",
        filename="best_iceberg_model.pth",
    )

    model = IcebergClassifier()
    # NOTE(review): torch.load unpickles the downloaded file, which can run
    # arbitrary code if the Hub repo were ever compromised. Consider passing
    # weights_only=True (torch >= 1.13). It is left unchanged here because the
    # installed torch version is not visible from this file.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables Dropout
    return model


def load_model():
    """Return the cached ``IcebergClassifier``, or None if loading fails.

    On failure the error is shown to the user with ``st.error``. Because
    failures are not cached, the next call tries the download again.
    """
    try:
        return _fetch_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None

def _bands_to_tensor(band1, band2):
    """Stack two radar bands into a float32 tensor of shape (1, 2, 75, 75).

    Each band may be any nested sequence holding 75 * 75 = 5625 numbers
    (e.g. a 75x75 list of lists, or a flat list).

    Raises:
        ValueError: if the values include NaN or Infinity. ``json.loads``
            accepts these literals, but they would yield a meaningless
            prediction.
        ValueError / TypeError: raised by NumPy when a band is not numeric or
            does not hold exactly 5625 values.
    """
    bands = [np.asarray(band, dtype=np.float32).reshape(75, 75) for band in (band1, band2)]
    image = np.stack(bands)  # (2, 75, 75): band 1 in channel 0, band 2 in channel 1
    if not np.isfinite(image).all():
        raise ValueError("input contains non-finite values (NaN or Infinity)")
    return torch.from_numpy(image).unsqueeze(0)  # add the batch dimension


def predict(band1, band2):
    """Return the model's iceberg probability for one pair of radar bands.

    Args:
        band1: first band (HH in the UI), 75 * 75 numbers.
        band2: second band (HV in the UI), 75 * 75 numbers.

    Returns:
        The model's sigmoid output as a Python float in [0, 1]. Returns None
        if the model could not be loaded or the input could not be used; in
        both cases an error has been shown via ``st.error``.
    """
    model = load_model()
    if model is None:
        return None

    try:
        image = _bands_to_tensor(band1, band2)
        with torch.no_grad():
            pred = model(image)
        return pred.item()
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        return None

# Streamlit UI
def _is_75x75_grid(data):
    """Return True if *data* is a list of 75 lists with 75 entries each.

    Every row is checked, not just the first, so ragged input gets the shape
    error message instead of slipping through and failing inside ``predict``.
    """
    return (
        isinstance(data, list)
        and len(data) == 75
        and all(isinstance(row, list) and len(row) == 75 for row in data)
    )


def _show_prediction(probability):
    """Display *probability* as a metric, a label, a progress bar and a confidence tier.

    *probability* is the iceberg probability returned by ``predict``. Values
    above 0.5 are labelled ICEBERG; everything else is labelled SHIP.
    """
    st.subheader('Prediction Result')

    result_col1, result_col2 = st.columns(2)
    with result_col1:
        st.metric("Iceberg Probability", f"{probability:.2%}")
    with result_col2:
        if probability > 0.5:
            st.success('🧊 ICEBERG')
        else:
            st.success('🚢 SHIP')

    st.progress(probability)

    # Distance from the 0.5 decision boundary, rescaled to [0, 1].
    confidence = abs(probability - 0.5) * 2
    if confidence > 0.8:
        st.write("High confidence prediction")
    elif confidence > 0.4:
        st.write("Medium confidence prediction")
    else:
        st.write("Low confidence prediction")


def _run_prediction(band1_text, band2_text):
    """Parse and validate the two JSON inputs, run ``predict`` and show the result.

    Problems are reported with ``st.warning`` / ``st.error``. Returning early
    ends only this step, so ``main`` still renders the rest of the page.
    """
    if not band1_text or not band2_text:
        st.warning('Please enter data for both bands')
        return

    try:
        band1_data = json.loads(band1_text)
        band2_data = json.loads(band2_text)

        if not (_is_75x75_grid(band1_data) and _is_75x75_grid(band2_data)):
            st.error('Arrays must be 75x75 dimensions')
            return

        with st.spinner('Making prediction...'):
            probability = predict(band1_data, band2_data)

        # predict() has already reported an error when it returns None.
        if probability is not None:
            _show_prediction(probability)

    except json.JSONDecodeError:
        st.error('Please enter valid JSON format data')
    except Exception as e:
        st.error(f'An error occurred: {str(e)}')


def main():
    """Render the page: intro, sidebar help, band inputs, prediction, example input and footer."""
    st.title('🧊 Iceberg vs Ship Classifier')
    st.write("""
    This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
    The model uses two radar bands (HH and HV polarization) to make its prediction.
    """)

    # Sidebar with information
    st.sidebar.header("About")
    st.sidebar.info("""
    This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
    It uses radar data from two different polarizations to distinguish between ships and icebergs.
    """)

    st.sidebar.header("Input Format")
    st.sidebar.info("""
    Each band should be a 75x75 array of radar backscatter values in dB.
    Values are typically between -50 and 50.
    """)

    # Main content
    st.subheader('Radar Image Data')
    col1, col2 = st.columns(2)

    with col1:
        band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')

    with col2:
        band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')

    if st.button('🔍 Predict'):
        # Handled in a helper so validation failures do not skip the example
        # input and footer rendered below.
        _run_prediction(band1_text, band2_text)

    # Example usage
    with st.expander("See example input"):
        st.code('''
# Example data format for each band:
[
    [-32.5, -31.2, -30.8, ...],  # 75 values
    [-31.8, -30.9, -31.1, ...],  # 75 values
    ...                          # 75 rows total
]
        ''')

    # Footer
    st.markdown('---')
    st.markdown("""
    <div style='text-align: center'>
        <p>Made with ❤️ using Streamlit | 
        <a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
    </div>
    """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()