"""Streamlit app that classifies an uploaded digit image with a small CNN."""

import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


def accuracy(outputs, labels):
    """Return the fraction of predictions in *outputs* that match *labels*.

    Args:
        outputs: Per-class scores; the predicted class is the arg-max over
            dimension 1.
        labels: Target class indices, one per row of *outputs*.

    Returns:
        A 0-dim tensor holding ``correct / len(preds)``.
    """
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    """Training/validation helpers shared by image classifiers.

    Subclasses only need to implement ``forward``.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)  # Generate predictions
        return F.cross_entropy(out, labels)  # Calculate loss

    def validation_step(self, batch):
        """Return detached loss and accuracy for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)  # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        acc = accuracy(out, labels)  # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts from ``validation_step`` into floats."""
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the epoch's train/validation metrics."""
        print(
            f"Epoch [{epoch}], train_loss: {result['train_loss']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, "
            f"val_acc: {result['val_acc']:.4f}"
        )


class CnnModel(ImageClassificationBase):
    """Fully convolutional 10-class classifier.

    The layer order below must not change: the keys of the saved state dict
    are the positional indices of this ``nn.Sequential``.

    NOTE(review): the network ends in ``nn.Softmax`` while ``training_step``
    uses ``F.cross_entropy``, which applies log-softmax itself, so softmax is
    effectively applied twice during training. The arg-max used at inference
    is unaffected; left as-is because the checkpoint was produced with this
    architecture.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            # Stage 1: padded 3x3 convolutions keep the spatial size.
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # 1x1 bottleneck back to 16 channels, then halve the resolution.
            nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(2, 2),
            nn.Dropout(p=0.25),
            # Stage 2: unpadded 3x3 convolutions, each trims 2 pixels per axis.
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # Classifier head: 1x1 conv to 10 class maps, pooled to one value
            # per class (for 28x28 input: 6x6 -> 8x8 after padding -> 1x1).
            nn.Conv2d(256, 10, kernel_size=1, stride=1, padding=1),
            nn.AvgPool2d(kernel_size=6),
            nn.Flatten(),
            nn.Softmax(dim=1),
        )

    def forward(self, xb):
        """Run a batch of images through the network; returns (N, 10) scores."""
        return self.network(xb)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model(weights_path='svhn_cnn.pth'):
    """Build the CNN, load its weights and put it in eval mode.

    Streamlit re-executes this script on every interaction; caching the
    result avoids re-reading the checkpoint and rebuilding the network on
    each rerun.
    """
    net = CnnModel()
    # NOTE: torch.load unpickles the file, so only load checkpoints you
    # trust (or pass weights_only=True on PyTorch versions that support it).
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net = net.to(device)
    net.eval()  # Disable dropout and use BatchNorm running statistics.
    return net


model = _load_model()

# Built once at import time instead of on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # Maps the [0, 1] range produced by ToTensor to [-1, 1] per channel.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def predict_and_display(image_path, model):
    """Predict the digit in an image and render both in the Streamlit page.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts: a filesystem path or
            a binary file-like object such as a Streamlit upload.
        model: Callable network returning per-class scores of shape (1, 10).
    """
    image = Image.open(image_path).convert('RGB')
    # Add the batch dimension the network expects and move to its device.
    image_tensor = _PREPROCESS(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(image_tensor)
    _, predicted = torch.max(outputs, 1)

    st.image(image, caption='Uploaded Image', use_container_width=True)
    st.write(f"Predicted Digit: {predicted.item()}")


# Streamlit UI
st.title("Digit Prediction from Image")
st.write("Upload an image of a digit for prediction:")
uploaded_image = st.file_uploader("Choose an image...", type="png")

if uploaded_image is not None:
    # The upload is an in-memory file object, so hand it to PIL directly.
    # Writing it to a fixed path would let concurrent sessions overwrite
    # each other's image and leave a stray file behind.
    predict_and_display(uploaded_image, model)