# Install necessary packages # Ensure you have PyTorch, torchvision, and Streamlit installed # You can install them using pip if you haven't already: # pip install torch torchvision streamlit import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, models, transforms from torch.utils.data import DataLoader import numpy as np import time import os import copy import streamlit as st from PIL import Image import matplotlib.pyplot as plt import torchvision.transforms as T # Data transformations data_transforms = { 'train': transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), } # Load datasets data_dir = 'path/to/data' image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Load the pre-trained model model_ft = models.resnet18(pretrained=True) num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, len(class_names)) model_ft = model_ft.to(device) # Define loss function and optimizer criterion = nn.CrossEntropyLoss() optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) # Training and evaluation functions def train_model(model, criterion, optimizer, scheduler, num_epochs=25): since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 running_corrects = 0 for inputs, labels in dataloaders[phase]: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) if phase == 'train': loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format( phase, epoch_loss, epoch_acc)) if phase == 'val' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format( time_elapsed // 60, time_elapsed % 60)) print('Best val Acc: {:4f}'.format(best_acc)) model.load_state_dict(best_model_wts) return model # Train the model model_ft = train_model(model_ft, criterion, optimizer_ft, scheduler, num_epochs=25) # Save the trained model torch.save(model_ft.state_dict(), 'model_ft.pth') # Streamlit Interface st.title("Image Classification with Fine-tuned ResNet") uploaded_file = st.file_uploader("Choose an image...", type="jpg") if uploaded_file is not None: image = Image.open(uploaded_file) st.image(image, caption='Uploaded Image.', use_column_width=True) st.write("") st.write("Classifying...") model_ft = models.resnet18(pretrained=True) num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, len(class_names)) model_ft.load_state_dict(torch.load('model_ft.pth')) model_ft = model_ft.to(device) model_ft.eval() preprocess = T.Compose([ T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) img = preprocess(image).unsqueeze(0) img = img.to(device) with torch.no_grad(): outputs = model_ft(img) _, preds = torch.max(outputs, 1) predicted_class = class_names[preds[0]] st.write(f"Predicted Class: {predicted_class}") # Plotting the image with matplotlib fig, ax = plt.subplots() ax.imshow(image) ax.set_title(f"Predicted: {predicted_class}") st.pyplot(fig)