# alperugurcan's picture
# Update app.py
# 778467b verified
# raw
# history blame
# 1.44 kB
import ast

import numpy as np
import streamlit as st
import torch, torch.nn as nn
from huggingface_hub import hf_hub_download
class IcebergClassifier(nn.Module):
    """Three-stage CNN followed by a two-layer head ending in a sigmoid.

    The input has 2 channels. Each convolutional stage is a 3x3 convolution
    (padding 1), a ReLU and a 2x2 max-pool, so the spatial size is halved
    (rounding down) three times. The first linear layer is sized for a
    64-channel 9x9 feature map, which is what a 75x75 input produces
    (75 -> 37 -> 18 -> 9).

    The submodule names ``conv`` and ``fc`` and the positions of the layers
    inside them determine the ``state_dict`` keys, so they must not change
    if saved weights are to keep loading.
    """

    def __init__(self):
        super().__init__()

        # Channel widths of consecutive stages: 2 -> 16 -> 32 -> 64.
        widths = (2, 16, 32, 64)
        stages = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
        self.conv = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Linear(64 * 9 * 9, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one sigmoid score per sample, shape ``(batch, 1)``."""
        feature_maps = self.conv(x)
        flattened = feature_maps.view(x.size(0), -1)
        return self.fc(flattened)
@st.cache_resource
def _load_model():
    """Build the classifier, load its published weights and return it in eval mode.

    Decorated with ``st.cache_resource`` so the checkpoint is fetched and
    deserialised once and the same model object is reused on later calls,
    instead of being rebuilt for every new pair of inputs.
    """
    weights_path = hf_hub_download("alperugurcan/iceberg", "best_iceberg_model.pth")
    model = IcebergClassifier()
    # NOTE(review): torch.load unpickles the file. The checkpoint comes from a
    # fixed Hub repository; consider passing weights_only=True if the
    # installed torch version supports it.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return model.eval()


def predict(b1, b2):
    """Return the model's sigmoid score for one two-band sample.

    Args:
        b1: Values of the first band; anything ``numpy`` can convert to a
            float array with exactly 75*75 elements (nested or flat).
        b2: Values of the second band, same requirements as ``b1``.

    Returns:
        The single model output as a Python float in (0, 1).

    Raises:
        ValueError: If a band cannot be converted to floats or does not
            contain exactly 75*75 values.
    """
    # Convert explicitly to float32 so integer-only input is accepted too and
    # the tensor dtype does not depend on what the caller happened to pass.
    bands = np.stack(
        [np.asarray(band, dtype=np.float32).reshape(75, 75) for band in (b1, b2)]
    )
    batch = torch.from_numpy(bands).unsqueeze(0)  # add the batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return _load_model()(batch).item()
# --- Streamlit page: show sample inputs, read two bands, display the verdict ---
st.title('🧊 Iceberg vs Ship Classifier')

# Constant-valued 75x75 grids shown in full so they can be copied into the
# text areas below.
examples = {"Ship":[[-32.5]*75]*75,"Iceberg":[[-28.3]*75]*75,"Strong":[[-35.8]*75]*75}
for k, v in examples.items():
    st.code(f"Example ({k}): {v}")

b1, b2 = st.text_area('Band 1'), st.text_area('Band 2')

if st.button('Predict'):
    try:
        # SECURITY: the text areas hold untrusted user input. This used to be
        # passed to eval(), which executes arbitrary Python. ast.literal_eval
        # only accepts literals (e.g. nested lists of numbers), so expressions
        # such as "[[0.0]*75]*75" must now be written out in full, exactly as
        # the examples above are displayed.
        p = predict(ast.literal_eval(b1), ast.literal_eval(b2))
        st.write(f"{'🧊 ICEBERG' if p>0.5 else '🚒 SHIP'} ({p:.1%})")
    except Exception:
        # Any parse, shape or model failure is reported the same way; unlike a
        # bare except this lets KeyboardInterrupt/SystemExit propagate.
        st.error('Invalid input')