# Source: Hugging Face file view of app.py by alperugurcan
# Commit ab484fd (verified): "Create app.py" — 5.81 kB
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import json
import requests
from huggingface_hub import hf_hub_download
# Model Architecture
class IcebergClassifier(nn.Module):
    """Small CNN mapping a 2-channel 75x75 radar patch to an iceberg score.

    Feature extractor: three conv(3x3, padding=1) -> ReLU -> max-pool(2)
    stages widening 2 -> 16 -> 32 -> 64 channels. Each pool halves the spatial
    size (75 -> 37 -> 18 -> 9), leaving a 64x9x9 map. Head: Linear -> ReLU ->
    Dropout(0.5) -> Linear -> Sigmoid, producing one value in (0, 1) per sample.

    The layer order inside ``self.conv`` and ``self.fc`` must not change:
    ``nn.Sequential`` names sub-modules by position, and those names
    (``conv.0``, ``conv.3``, ``conv.6``, ``fc.0``, ``fc.3``) are the keys the
    pretrained state dict is loaded against.
    """

    def __init__(self):
        super().__init__()

        def stage(in_ch, out_ch):
            # padding=1 keeps H/W through the 3x3 conv; the 2x2 pool halves them.
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        self.conv = nn.Sequential(*stage(2, 16), *stage(16, 32), *stage(32, 64))

        flat_features = 64 * 9 * 9  # channels x (75 // 2 // 2 // 2) ** 2
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for input batch ``x``."""
        features = self.conv(x)
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)
# Model loading with HuggingFace Hub
@st.cache_resource
def _download_and_build_model():
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Cached with ``st.cache_resource`` so the download/load happens once and
    the model object is reused. This function deliberately lets exceptions
    propagate: Streamlit does not cache raised exceptions, so a transient
    failure (e.g. a network error) is retried on the next call instead of a
    ``None`` result being cached for the life of the process.

    Returns:
        An ``IcebergClassifier`` in eval mode with the pretrained weights.
    """
    model_path = hf_hub_download(
        repo_id="alperugurcan/iceberg",
        filename="best_iceberg_model.pth",
    )
    model = IcebergClassifier()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute code on load (needs torch >= 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Return the pretrained IcebergClassifier, or None if it cannot be loaded.

    Any loading error is shown in the UI with ``st.error``; callers must
    handle a None return.
    """
    try:
        return _download_and_build_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
def predict(band1, band2):
    """Run the classifier on one pair of radar bands.

    Args:
        band1: HH-polarization values. Any list/array holding exactly
            75*75 numbers (a 75x75 grid or a flat list of 5625); it is
            reshaped to 75x75.
        band2: HV-polarization values, same format as ``band1``.

    Returns:
        The model's sigmoid output (iceberg probability) as a float, or None
        if the model is unavailable or the input could not be processed.
        Processing errors are reported in the UI with ``st.error``.
    """
    model = load_model()
    if model is None:
        return None
    try:
        # Build a (1, 2, 75, 75) float32 batch: channel 0 = band1, channel 1 = band2.
        channels = [
            np.asarray(band, dtype=np.float32).reshape(75, 75)
            for band in (band1, band2)
        ]
        image = np.stack(channels)
        if not np.isfinite(image).all():
            # Python's json parser accepts NaN/Infinity; such values can turn
            # the output into NaN, which the result display cannot show.
            raise ValueError("band values must be finite numbers")
        batch = torch.as_tensor(image).unsqueeze(0)
        with torch.no_grad():
            return model(batch).item()
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        return None
# Streamlit UI
def main():
    """Render the Iceberg vs Ship classifier page.

    Shows the intro and sidebar help, two JSON text areas (one per radar
    band), and, when "Predict" is clicked, validates the input and displays
    the model's result. The example-input expander and footer are always
    rendered, even when the submitted input is rejected.
    """
    st.title('🧊 Iceberg vs Ship Classifier')
    st.write("""
This application uses satellite radar data to predict whether an image shows an iceberg or a ship.
The model uses two radar bands (HH and HV polarization) to make its prediction.
""")
    _render_sidebar()

    # Main content
    st.subheader('Radar Image Data')
    col1, col2 = st.columns(2)
    with col1:
        band1_text = st.text_area('Band 1 (HH Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')
    with col2:
        band2_text = st.text_area('Band 2 (HV Polarization)', height=150,
                                  help='Enter 75x75 array in JSON format')

    if st.button('🔍 Predict'):
        # Handled in a helper so a rejected input only ends the handler,
        # not main() -- the example and footer below still render.
        _handle_predict(band1_text, band2_text)

    # Example usage
    with st.expander("See example input"):
        st.code('''
# Example data format for each band:
[
[-32.5, -31.2, -30.8, ...], # 75 values
[-31.8, -30.9, -31.1, ...], # 75 values
... # 75 rows total
]
''')

    # Footer
    st.markdown('---')
    st.markdown("""
<div style='text-align: center'>
<p>Made with ❤️ using Streamlit |
<a href='https://huggingface.co/alperugurcan/iceberg'>Model on HuggingFace</a></p>
</div>
""", unsafe_allow_html=True)


def _render_sidebar():
    """Render the sidebar "About" and "Input Format" help panels."""
    st.sidebar.header("About")
    st.sidebar.info("""
This model was trained on the Statoil/C-CORE Iceberg Classifier Challenge dataset.
It uses radar data from two different polarizations to distinguish between ships and icebergs.
""")
    st.sidebar.header("Input Format")
    st.sidebar.info("""
Each band should be a 75x75 array of radar backscatter values in dB.
Values are typically between -50 and 50.
""")


def _is_75x75(data):
    """Return True if ``data`` is a list of 75 rows, each a list of 75 items."""
    return (
        isinstance(data, list)
        and len(data) == 75
        and all(isinstance(row, list) and len(row) == 75 for row in data)
    )


def _handle_predict(band1_text, band2_text):
    """Parse and validate both JSON bands, run ``predict`` and show the outcome."""
    # Treat whitespace-only input as missing rather than as invalid JSON.
    if not (band1_text or '').strip() or not (band2_text or '').strip():
        st.warning('Please enter data for both bands')
        return
    try:
        band1_data = json.loads(band1_text)
        band2_data = json.loads(band2_text)
        # Check every row, not only the first, so ragged or non-list JSON
        # gets this message instead of a generic error later on.
        if not (_is_75x75(band1_data) and _is_75x75(band2_data)):
            st.error('Arrays must be 75x75 dimensions')
            return
        with st.spinner('Making prediction...'):
            probability = predict(band1_data, band2_data)
        if probability is not None:
            _render_result(probability)
    except json.JSONDecodeError:
        st.error('Please enter valid JSON format data')
    except Exception as e:
        st.error(f'An error occurred: {str(e)}')


def _render_result(probability):
    """Show the iceberg probability, predicted label, progress bar and confidence."""
    st.subheader('Prediction Result')
    result_col1, result_col2 = st.columns(2)
    with result_col1:
        st.metric("Iceberg Probability", f"{probability:.2%}")
    with result_col2:
        # Scores above the 0.5 decision threshold are labelled "iceberg".
        if probability > 0.5:
            st.success('🧊 ICEBERG')
        else:
            st.success('🚢 SHIP')
    st.progress(probability)
    # Distance from the 0.5 threshold, rescaled to [0, 1].
    confidence = abs(probability - 0.5) * 2
    if confidence > 0.8:
        st.write("High confidence prediction")
    elif confidence > 0.4:
        st.write("Medium confidence prediction")
    else:
        st.write("Low confidence prediction")
# Render the app only when this file is executed as the main script.
if __name__ == '__main__':
    main()