shukdevdatta123's picture
Update app.py
5dc35f8 verified
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
# Fraction of correct top-1 predictions in a batch
def accuracy(outputs, labels):
    """Return the share of predictions in ``outputs`` that match ``labels``.

    Args:
        outputs: Per-class scores; the predicted class is the index of the
            maximum along ``dim=1``.
        labels: Expected class indices, compared element-wise with the
            predicted indices.

    Returns:
        A zero-dimensional tensor created from the Python float
        ``matches / total``.
    """
    predicted = torch.max(outputs, dim=1).indices
    matches = (predicted == labels).sum().item()
    total = len(predicted)
    return torch.tensor(matches / total)
class ImageClassificationBase(nn.Module):
    """Base module bundling per-batch and per-epoch training helpers.

    Every helper obtains predictions by calling ``self(images)``, so a
    subclass has to provide ``forward``.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def validation_step(self, batch):
        """Score one ``(images, labels)`` batch.

        Returns:
            A dict with the detached cross-entropy loss under ``'val_loss'``
            and the batch accuracy under ``'val_acc'``.
        """
        inputs, targets = batch
        predictions = self(inputs)
        batch_loss = F.cross_entropy(predictions, targets)
        batch_acc = accuracy(predictions, targets)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results of ``validation_step``.

        Each batch carries equal weight: the values are stacked and reduced
        with a plain mean, then converted to Python floats.
        """
        def stacked_mean(key):
            return torch.stack([entry[key] for entry in outputs]).mean()

        mean_loss = stacked_mean('val_loss')
        mean_acc = stacked_mean('val_acc')
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the metrics stored in ``result``."""
        template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        metrics = (result['train_loss'], result['val_loss'], result['val_acc'])
        print(template.format(epoch, *metrics))
# Model architecture (must match the layout of the saved checkpoint)
class CnnModel(ImageClassificationBase):
    """Convolutional classifier with 3 input channels and 10 outputs.

    All layers live in a single ``nn.Sequential`` stored as ``self.network``.
    The order and number of layers determine the ``network.<index>.*`` keys
    of the state dict, so they must not be rearranged.

    NOTE(review): the stack ends in ``nn.Softmax`` while the base class feeds
    the output to ``F.cross_entropy``, which is documented to expect
    unnormalised scores — confirm this pairing is intentional.
    """

    def __init__(self):
        super().__init__()

        # 3x3 convolutions with padding=1, widening 3 -> 256 channels.
        padded_convs = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        ]

        # 1x1 convolution back down to 16 channels, then 2x2 max-pooling.
        squeeze_and_pool = [
            nn.Conv2d(256, 16, kernel_size=1),
            nn.MaxPool2d(2, 2),
            nn.Dropout(p=0.25),
        ]

        # Unpadded 3x3 convolutions, widening 16 -> 256 channels again.
        unpadded_convs = [
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        ]

        # 1x1 convolution to 10 channels, average-pool, flatten, softmax.
        classifier_head = [
            nn.Conv2d(256, 10, kernel_size=1, padding=1),
            nn.AvgPool2d(kernel_size=6),
            nn.Flatten(),
            nn.Softmax(dim=1),
        ]

        self.network = nn.Sequential(
            *padded_convs,
            *squeeze_and_pool,
            *unpadded_convs,
            *classifier_head,
        )

    def forward(self, xb):
        """Run ``xb`` through the sequential stack and return its output."""
        return self.network(xb)
# Load the saved model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model():
    """Build ``CnnModel``, load ``svhn_cnn.pth`` into it and return it in eval mode.

    Streamlit re-executes this script from top to bottom on every user
    interaction. ``st.cache_resource`` keeps the loaded model alive between
    those reruns so the checkpoint is read from disk only once per process.
    """
    net = CnnModel()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    net.load_state_dict(torch.load('svhn_cnn.pth', map_location=device))
    net = net.to(device)
    net.eval()
    return net


model = _load_model()
# Function to predict the digit and display the image
def predict_and_display(image_path, model):
    """Classify one image with ``model`` and render the result in Streamlit.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a file name, a
            ``pathlib.Path`` or a binary file object).
        model: Callable module that maps the preprocessed image tensor to
            per-class scores; the index of the largest score is reported.

    Returns:
        None. The image and the predicted digit are written to the page.
    """
    # Preprocess the image: resize to 28x28, convert to a tensor and
    # normalise each of the three channels with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Image.open is lazy; the context manager guarantees that any file handle
    # Pillow opened is released once the RGB copy has been made.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Perform prediction without tracking gradients
    with torch.no_grad():
        outputs = model(image_tensor)
    _, predicted = torch.max(outputs, 1)

    # Display the image and prediction
    st.image(image, caption='Uploaded Image', use_container_width=True)
    st.write(f"Predicted Digit: {predicted.item()}")
# Streamlit UI
st.title("Digit Prediction from Image")
st.write("Upload an image of a digit for prediction:")
uploaded_image = st.file_uploader("Choose an image...", type="png")
if uploaded_image is not None:
    # Pass the in-memory upload straight to the predictor. Writing it to a
    # fixed file name first would let concurrent sessions overwrite each
    # other's upload and would leave files behind on disk; PIL can read the
    # file-like upload object directly.
    predict_and_display(uploaded_image, model)