Spaces:
Sleeping
Sleeping
import ast

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
class IcebergClassifier(nn.Module):
    """Small CNN that scores a two-band 75x75 image as iceberg vs ship.

    Expects input of shape (batch, 2, 75, 75). Three conv -> ReLU -> 2x2
    max-pool stages shrink the 75x75 grid to 37 -> 18 -> 9, which is why the
    first linear layer takes 64 * 9 * 9 features. The head ends in a sigmoid,
    so the output has shape (batch, 1) with values in [0, 1].

    The order of modules inside ``conv`` and ``fc`` fixes the ``conv.N`` /
    ``fc.N`` state_dict keys, so it must stay unchanged for the checkpoint
    loaded via ``load_state_dict`` to keep matching.
    """

    def __init__(self):
        super().__init__()
        feature_layers = []
        for in_ch, out_ch in ((2, 16), (16, 32), (32, 64)):
            feature_layers.extend(
                (nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
            )
        self.conv = nn.Sequential(*feature_layers)
        head_layers = (
            nn.Linear(64 * 9 * 9, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.conv(x)
        # Flatten everything except the batch dimension: (batch, 64 * 9 * 9).
        flattened = features.view(x.size(0), -1)
        return self.fc(flattened)
@st.cache_resource
def _load_model():
    """Download the checkpoint and return an eval-mode IcebergClassifier.

    Wrapped in ``st.cache_resource`` so the Hub lookup, model construction
    and ``load_state_dict`` happen once per server process. Streamlit
    re-executes the whole script on every interaction, so an ordinary
    module-level cache would be rebuilt on each click.
    """
    weights_path = hf_hub_download("alperugurcan/iceberg", "best_iceberg_model.pth")
    # The file comes from a remote repo: weights_only=True refuses to unpickle
    # arbitrary objects and accepts only tensors/containers (a state_dict).
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model = IcebergClassifier()
    model.load_state_dict(state_dict)
    return model.eval()


def predict(b1, b2):
    """Return the model's sigmoid score for a pair of 75x75 bands.

    Args:
        b1, b2: anything ``np.asarray`` can turn into exactly 75 * 75 numbers
            (e.g. a nested list of floats or a numpy array).

    Returns:
        A float in [0, 1]; higher means "iceberg" (the UI thresholds at 0.5).

    Raises:
        ValueError: if a band cannot be converted to a 75x75 float array.
    """
    # Validate/convert the input before touching the network or the model.
    bands = np.stack([
        np.asarray(b1, dtype=np.float32).reshape(75, 75),
        np.asarray(b2, dtype=np.float32).reshape(75, 75),
    ])
    batch = torch.from_numpy(bands).unsqueeze(0)  # (1, 2, 75, 75)
    with torch.no_grad():  # inference only; skip building an autograd graph
        return _load_model()(batch).item()
def _parse_band(text):
    """Safely parse one user-typed band into a 75x75 float32 array.

    Uses ``ast.literal_eval`` rather than ``eval``: the text can only be a
    Python literal (e.g. nested lists of numbers), never executable code.

    Raises:
        ValueError: if the text is not a literal, is not numeric, or does not
            hold exactly 75 * 75 values.
    """
    try:
        return np.asarray(ast.literal_eval(text), dtype=np.float32).reshape(75, 75)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise ValueError(f"could not parse band: {err}") from err


st.title('🧊 Iceberg vs Ship Classifier')
examples = {
    "Ship": [[-32.5] * 75] * 75,
    "Iceberg": [[-28.3] * 75] * 75,
    "Strong": [[-35.8] * 75] * 75,
}
for k, v in examples.items():
    st.code(f"Example ({k}): {v}")
b1, b2 = st.text_area('Band 1'), st.text_area('Band 2')
if st.button('Predict'):
    try:
        band1, band2 = _parse_band(b1), _parse_band(b2)
    except ValueError:
        st.error('Invalid input')
    else:
        try:
            p = predict(band1, band2)
        except Exception as err:  # UI boundary: model download/inference failure, not bad input
            st.error(f'Prediction failed: {err}')
        else:
            # p is the iceberg score; the percentage shown is that score.
            st.write(f"{'🧊 ICEBERG' if p > 0.5 else '🚢 SHIP'} ({p:.1%})")