Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms, models | |
from PIL import Image, UnidentifiedImageError | |
import streamlit as st | |
import numpy as np | |
import requests | |
from io import BytesIO | |
from kan_linear import KANLinear | |
import logging | |
import os | |
# Setup logging: configure the root logger at INFO so the logging.info(...)
# progress messages emitted throughout this app are shown.
logging.basicConfig(level="INFO")
# Define the model
class KANVGG16(nn.Module):
    """VGG16-layout CNN whose fully connected head is built from KANLinear layers.

    The backbone has five stages of 3x3 convolutions (2, 2, 3, 3, 3 convs with
    64, 128, 256, 512, 512 output channels). Each stage ends in a 2x2 max-pool
    followed by BatchNorm2d. Five halvings turn a 224x224 input into a 7x7
    map with 512 channels, which is what the first KANLinear(512 * 7 * 7, ...)
    expects.

    Args:
        num_classes: number of outputs of the final KANLinear; the default of
            1 is a single logit for binary (cat vs. dog) classification.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Integers are conv output channels; "M" closes a stage with
        # MaxPool2d + BatchNorm2d. The resulting module order (features.0 ...
        # features.35) must stay fixed because load_state_dict matches
        # weights by these indices.
        stage_plan = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
                      512, 512, 512, "M", 512, 512, 512, "M"]
        backbone = []
        channels = 3  # 3-channel (RGB) input images
        for step in stage_plan:
            if step == "M":
                backbone.append(nn.MaxPool2d(kernel_size=2, stride=2))
                backbone.append(nn.BatchNorm2d(channels))
                continue
            backbone.append(nn.Conv2d(channels, step, kernel_size=3, padding=1))
            backbone.append(nn.ReLU(inplace=True))
            channels = step
        self.features = nn.Sequential(*backbone)

        hidden = 2048
        self.classifier = nn.Sequential(
            KANLinear(512 * 7 * 7, hidden),  # flattened 512x7x7 map (224x224 input)
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(hidden, num_classes),
        )

    def forward(self, x):
        """Run the backbone, flatten all but the batch axis, and apply the head.

        The output is un-normalized (no activation after the last KANLinear).
        """
        feature_map = self.features(x)
        flat = feature_map.flatten(start_dim=1)
        return self.classifier(flat)
def load_model(weights_path, device):
    """Build a KANVGG16 on ``device``, load a state dict into it, and return it in eval mode.

    Keys starting with ``module.`` have that prefix removed before loading.
    Such prefixes presumably come from a checkpoint saved through a
    DataParallel/DDP wrapper; verify against the training script.

    Args:
        weights_path: path to a file readable by ``torch.load`` that holds a
            mapping of parameter names to tensors.
        device: ``torch.device`` the model and tensors are placed on.

    Returns:
        The loaded ``KANVGG16`` with ``eval()`` already applied.
    """
    net = KANVGG16().to(device)
    # NOTE(review): torch.load unpickles the file; only point it at trusted weights.
    checkpoint = torch.load(weights_path, map_location=device)
    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }
    net.load_state_dict(cleaned)
    net.eval()
    return net
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be obtained or decoded; the message describes why."""
def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: URL whose path must end in .jpg, .jpeg, .png or .webp
            (case-insensitive). Query strings and fragments are ignored for
            this check.
        timeout: seconds passed to ``requests.get``. Without a timeout the
            request can block indefinitely on an unresponsive server.

    Returns:
        A PIL image converted to RGB mode.

    Raises:
        CustomImageLoadingError: if the extension is not allowed, the response
            Content-Type does not mention "image", the request fails (HTTP
            error, connection problem, timeout), the bytes cannot be decoded
            as an image, or any other error occurs while loading.
    """
    from urllib.parse import urlparse

    try:
        logging.info(f"Loading image from URL: {url}")
        # Check the file extension on the URL *path* only, so that e.g.
        # "cat.jpg?w=500" is accepted and "a.txt?x=.jpg" is rejected.
        valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful
        # .get: a missing header should produce the clear "not an image"
        # error below rather than a bare KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")
        # Check if the content type is an image
        if 'image' not in content_type:
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a precise message; without
        # this clause the catch-all below would re-wrap them and log them as
        # "unexpected".
        raise
    except requests.HTTPError as e:
        # HTTPError subclasses RequestException, so it must be handled first.
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
def preprocess_image(image):
    """Resize ``image`` to 224x224, convert it with ToTensor, and add a batch axis.

    No mean/std normalization is applied. The result has a leading batch
    dimension of size 1, ready to pass straight to the model.
    """
    resize_then_tensor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    chw = resize_then_tensor(image)
    # Prepend a batch axis of size 1 (same as unsqueeze(0)).
    return chw[None, ...]
# Streamlit app
# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource (Streamlit >= 1.18) builds and loads the model once per
# (weights path, device) pair instead of on every rerun. The device is passed
# as a plain string so the cache key is trivially hashable.
@st.cache_resource
def _load_cached_model(weights_path, device_name):
    """Return the eval-mode model from load_model, cached across reruns and sessions."""
    return load_model(weights_path, torch.device(device_name))


st.title("Cat and Dog Classification with VGG16-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _load_cached_model('weights/best_model_vgg16_KAN.pth', device.type)
img = None
if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    # A corrupt or mislabeled upload would otherwise crash the script with a
    # traceback; report it in the sidebar like URL errors are.
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except UnidentifiedImageError as e:
        st.sidebar.error(f"Cannot identify image file: {e}")
    except OSError as e:
        st.sidebar.error(f"Error reading uploaded image: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        # Top-level UI boundary: surface anything else to the user.
        st.sidebar.error(f"Unexpected error: {e}")
st.sidebar.write("-----")
# Define your information for the footer
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")
# Footer with links and copyright. The content is pure Markdown, so raw HTML
# rendering (unsafe_allow_html) is not enabled.
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - 2024
""")
if img is not None:
    # Show the selected image; run inference only when "Predict" is clicked.
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        batch = preprocess_image(img).to(device)
        with torch.no_grad():
            logit = model(batch)
        # Sigmoid of the single model output; below 0.5 is reported as Cat,
        # 0.5 and above as Dog.
        prob = torch.sigmoid(logit).item()
        st.write(f"Prediction: {prob:.4f}")
        label = "Cat" if prob < 0.5 else "Dog"
        st.write(f"This image is classified as a {label}.")