Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
import json | |
import requests | |
from huggingface_hub import hf_hub_download | |
# Model Architecture | |
class IcebergClassifier(nn.Module):
    """Small CNN that scores a two-band radar patch as iceberg vs ship.

    Three feature stages (3x3 conv -> ReLU -> 2x2 max-pool) widen the channels
    2 -> 16 -> 32 -> 64 while halving the spatial size (75 -> 37 -> 18 -> 9 for
    a 75x75 input). A two-layer head maps the flattened 64*9*9 features to a
    single sigmoid output in (0, 1).

    The submodule names/positions (``conv.0``, ``conv.3``, ``conv.6``, ``fc.0``,
    ``fc.3``) are the keys that ``load_state_dict`` matches against the saved
    checkpoint, so they must not change.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor, assembled stage by stage. Each stage contributes
        # three modules, so the convs land at Sequential indices 0, 3 and 6.
        stages = []
        for in_ch, out_ch in ((2, 16), (16, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.conv = nn.Sequential(*stages)

        # 64 channels on a 9x9 grid after three 2x pooling steps.
        flat_features = 64 * 9 * 9
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``x``.

        ``x`` must have 2 channels, and its spatial size must pool down to 9x9
        after three halvings (the app feeds ``(1, 2, 75, 75)``). Otherwise the
        flattened width will not match the first linear layer.
        """
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
# Model loading with HuggingFace Hub | |
@st.cache_resource(show_spinner="Loading model from HuggingFace Hub...")
def _load_model_cached():
    """Download the checkpoint and build the classifier once per server process.

    ``st.cache_resource`` keeps the returned model across Streamlit reruns and
    sessions. Without it, every Predict click re-ran the download and rebuilt
    the network. Exceptions propagate and are not cached, so a transient
    failure (e.g. a network error) is retried on the next call.
    """
    model_path = hf_hub_download(
        repo_id="alperugurcan/iceberg",
        filename="best_iceberg_model.pth",
    )
    model = IcebergClassifier()
    # The file is a plain state_dict (it is passed straight to
    # load_state_dict), so restrict unpickling to tensors/containers. This
    # refuses arbitrary objects embedded in a remotely downloaded file.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables the Dropout layer
    return model


def load_model():
    """Return the pretrained ``IcebergClassifier`` in eval mode, or ``None``.

    The weights come from the ``alperugurcan/iceberg`` HuggingFace repo and
    are loaded onto the CPU. The model is cached after the first successful
    load. Any failure is shown in the UI via ``st.error`` and turned into a
    ``None`` return, so callers can bail out gracefully.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
def predict(band1, band2):
    """Return the model's iceberg probability for one pair of radar bands.

    Args:
        band1: HH-polarization values. Anything numpy can turn into exactly
            75*75 numbers works, e.g. nested 75x75 lists from ``json.loads``
            or a flat list of 5625 values.
        band2: HV-polarization values in the same format.

    Returns:
        The model's sigmoid output as a float in [0, 1], or ``None`` if the
        input was rejected or the model could not be loaded. Problems are
        reported to the user with ``st.error``.
    """
    try:
        # Convert straight to float32. Ragged or non-numeric input raises
        # here, and torch.from_numpy below can reuse the buffer.
        bands = [
            np.asarray(band, dtype=np.float32).reshape(75, 75)
            for band in (band1, band2)
        ]
        image = np.stack(bands)  # (2, 75, 75), channel order: band1, band2
    except (TypeError, ValueError, OverflowError) as e:
        st.error(f"Error during prediction: {str(e)}")
        return None

    # json.loads accepts NaN/Infinity literals (and huge values overflow to
    # inf in float32). They would yield a NaN probability that compares False
    # against any threshold and would be misreported as a class.
    if not np.isfinite(image).all():
        st.error("Error during prediction: input contains NaN or infinite values")
        return None

    # Load only after the input is known to be usable.
    model = load_model()
    if model is None:
        return None

    try:
        batch = torch.from_numpy(image).unsqueeze(0)  # (1, 2, 75, 75)
        with torch.no_grad():
            pred = model(batch)
        return pred.item()
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        return None
# Streamlit UI | |
def _render_sidebar():
    """Render the static "About" and "Input Format" panels in the sidebar."""
    st.sidebar.header("About")
    st.sidebar.info("""
    This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
    It uses radar data from two different polarizations to distinguish between ships and icebergs.
    """)
    st.sidebar.header("Input Format")
    st.sidebar.info("""
    Each band should be a 75x75 array of radar backscatter values in dB.
    Values are typically between -50 and 50.
    """)


def _is_75x75(data):
    """Return True if *data* is a list of exactly 75 lists of 75 items each.

    Every row is checked, not just the first, so ragged or non-list input is
    rejected here with a clear message instead of failing later.
    """
    return (
        isinstance(data, list)
        and len(data) == 75
        and all(isinstance(row, list) and len(row) == 75 for row in data)
    )


def _render_result(probability):
    """Display the iceberg probability, the predicted class and a confidence hint.

    Args:
        probability: Model output in [0, 1]. Values above 0.5 are labelled
            iceberg, everything else ship.
    """
    st.subheader('Prediction Result')
    result_col1, result_col2 = st.columns(2)
    with result_col1:
        st.metric("Iceberg Probability", f"{probability:.2%}")
    with result_col2:
        if probability > 0.5:
            st.success('🧊 ICEBERG')
        else:
            st.success('🚢 SHIP')
    st.progress(probability)
    # Distance from the 0.5 decision boundary, rescaled to [0, 1].
    confidence = abs(probability - 0.5) * 2
    if confidence > 0.8:
        st.write("High confidence prediction")
    elif confidence > 0.4:
        st.write("Medium confidence prediction")
    else:
        st.write("Low confidence prediction")


def _run_prediction(band1_text, band2_text):
    """Validate the two JSON text inputs, run ``predict`` and show the result.

    Problems are reported in the UI (``st.warning`` / ``st.error``) rather than
    propagated, so the caller can keep rendering the rest of the page.
    """
    if not band1_text or not band2_text:
        st.warning('Please enter data for both bands')
        return
    try:
        band1_data = json.loads(band1_text)
        band2_data = json.loads(band2_text)
        if not (_is_75x75(band1_data) and _is_75x75(band2_data)):
            st.error('Arrays must be 75x75 dimensions')
            return
        with st.spinner('Making prediction...'):
            probability = predict(band1_data, band2_data)
        if probability is not None:
            _render_result(probability)
    except json.JSONDecodeError:
        st.error('Please enter valid JSON format data')
    except Exception as e:
        st.error(f'An error occurred: {str(e)}')


def main():
    """Build the Streamlit page for the iceberg-vs-ship classifier.

    Renders the title and intro, the sidebar, the two band inputs with the
    Predict button, then the example-input expander and the footer.
    """
    st.title('🧊 Iceberg vs Ship Classifier')
    st.write("""
    This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
    The model uses two radar bands (HH and HV polarization) to make its prediction.
    """)

    _render_sidebar()

    # Main content
    st.subheader('Radar Image Data')
    col1, col2 = st.columns(2)
    with col1:
        band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')
    with col2:
        band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')

    if st.button('🔍 Predict'):
        # Handled in a helper so that an early exit on bad input does not skip
        # the example and footer rendered below, which the user needs most
        # precisely when their input was rejected.
        _run_prediction(band1_text, band2_text)

    # Example usage
    with st.expander("See example input"):
        st.code('''
# Example data format for each band:
[
    [-32.5, -31.2, -30.8, ...],  # 75 values
    [-31.8, -30.9, -31.1, ...],  # 75 values
    ...                          # 75 rows total
]
''')

    # Footer
    st.markdown('---')
    st.markdown("""
    <div style='text-align: center'>
    <p>Made with ❤️ using Streamlit |
    <a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
    </div>
    """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()