|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms, models |
|
from PIL import Image, UnidentifiedImageError |
|
import streamlit as st |
|
import numpy as np |
|
import requests |
|
from io import BytesIO |
|
from kan_linear import KANLinear |
|
import logging |
|
import os |
|
|
|
|
|
# Route the app's logging.info/error calls to stderr at INFO verbosity.
logging.basicConfig(level="INFO")
|
|
|
|
|
class CustomResNetKAN(nn.Module):
    """ResNet-50 backbone whose final classifier is a ``KANLinear`` head.

    The backbone is created with ``pretrained=False`` (random init); trained
    weights are expected to be loaded afterwards via ``load_state_dict``.

    Args:
        num_classes: Output width of the KAN head (1 by default, i.e. a single
            logit per image).
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour
        # of `weights=None`; switch once the minimum torchvision version is confirmed.
        backbone = models.resnet50(pretrained=False)
        # Replace the stock linear classifier with a KAN layer of the same input width.
        backbone.fc = KANLinear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and KAN head, returning the raw outputs."""
        return self.model(x)
|
|
|
def load_model(weights_path, device):
    """Build a ``CustomResNetKAN``, load trained weights, and return it in eval mode.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load``. Its
            top-level object is iterated with ``.items()`` and passed to
            ``load_state_dict``, so it must be a parameter-name -> tensor mapping.
        device: Device the model is moved to; checkpoint tensors are mapped
            onto the same device while loading.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    net = CustomResNetKAN().to(device)
    checkpoint = torch.load(weights_path, map_location=device)

    # Checkpoints saved from a wrapped model (typically nn.DataParallel /
    # DistributedDataParallel) prefix every key with "module."; strip it so
    # the keys match the unwrapped model.
    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }

    net.load_state_dict(cleaned)
    net.eval()
    return net
|
|
|
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be fetched, validated, or decoded."""
|
|
|
def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image. The *path* component must end in one of
            jpg, jpeg, png or webp; query strings and fragments are ignored
            when checking the extension.
        timeout: Seconds to wait for the server (passed to ``requests.get``)
            before giving up. Defaults to 10.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        CustomImageLoadingError: if the extension is not accepted, the response
            is not an image Content-Type, the request fails (network/HTTP
            error), or the bytes cannot be decoded as an image.
    """
    # Local import keeps this fix self-contained within the function.
    from urllib.parse import urlparse

    try:
        logging.info(f"Loading image from URL: {url}")

        valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
        # Inspect only the URL path so that e.g. ".../cat.jpg?w=300" is not
        # rejected because of its query string.
        url_path = urlparse(url).path
        file_extension = os.path.splitext(url_path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")

        # Without a timeout a stalled server would block the script indefinitely.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        # A missing header must be reported as "not an image", not as a KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")

        if 'image' not in content_type.lower():
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")

        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a precise message; let them
        # through unchanged instead of having the catch-all below re-wrap them.
        raise
    except requests.HTTPError as e:
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
|
|
|
def preprocess_image(image):
    """Convert a PIL image into a single-image batch tensor for the model.

    The image is resized to 224x224 and passed through ``transforms.ToTensor``
    (channels-first, values scaled to [0, 1]); no mean/std normalization is
    applied. A leading batch dimension of size 1 is then added.
    """
    pipeline = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
|
|
|
|
|
st.title("Cat and Dog Classification with ResNet-KAN")

st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def get_model(weights_path, device_name):
    """Load the classifier once per process and reuse it across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the checkpoint would be read from disk and
    the model rebuilt on every click. The device is passed as a string so
    the cache key is a plain hashable value.
    """
    return load_model(weights_path, torch.device(device_name))


model = get_model('weights/best_model_resnet50_KAN.pth', str(device))
|
|
|
# Resolve the image to classify: a file upload takes precedence over a URL.
img = None

if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except (UnidentifiedImageError, OSError) as e:
        # A corrupt or mislabelled upload should show a sidebar message rather
        # than crash the page with a traceback (mirrors the URL branch below).
        logging.error(f"Failed to decode uploaded image: {e}")
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")
|
|
|
# Sidebar footer: separator, author credit, and profile links.
st.sidebar.write("-----")

# Author name interpolated into the copyright line of the Markdown below.
name = "Wayan Dadang"

st.sidebar.write("Follow me on:")

# Static Markdown links; the year is a hard-coded literal inside the f-string.
# NOTE(review): the content contains no HTML, so unsafe_allow_html=True appears
# unnecessary — consider dropping it.
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)
|
|
|
# Show the selected image and, on request, classify it.
if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)

    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(img_tensor)
            prob = torch.sigmoid(output).item()

        st.write(f"Prediction: {prob:.4f}")

        # Single-logit head: a sigmoid probability below 0.5 is reported as
        # "Cat", anything else as "Dog".
        predicted_label = "Cat" if prob < 0.5 else "Dog"
        st.write(f"This image is classified as a {predicted_label}.")
|