# Spaces: Running  (Hugging Face Spaces status header captured during export — not code)
# Install necessary packages
# Ensure you have PyTorch, torchvision, and Streamlit installed
# You can install them using pip if you haven't already:
# pip install torch torchvision streamlit
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
# Data transformations: random crop/flip augmentation for training,
# deterministic resize + center crop for validation. Normalization uses the
# ImageNet channel means/stds that the pre-trained ResNet expects.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
# Load datasets laid out as data_dir/{train,val}/<class_name>/<image files>.
# NOTE(review): data_dir is a placeholder — point it at a real directory
# before running; ImageFolder raises if the path does not exist.
data_dir = 'path/to/data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# num_workers > 0 spawns worker processes; on Windows/macOS (spawn start
# method) this requires an `if __name__ == "__main__":` guard — TODO confirm
# the target platform.
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# Class names come from the train split's subdirectory names.
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the pre-trained model and replace its final fully-connected layer so
# the output dimension matches the number of classes in this dataset.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=models.ResNet18_Weights.DEFAULT` — switch if the
# installed torchvision supports the weights API.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by a factor of 10 every 7 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# Training and evaluation functions
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Fine-tune ``model`` and return it with the best-validation weights loaded.

    Runs one train phase and one val phase per epoch using the module-level
    ``dataloaders``, ``dataset_sizes``, and ``device``. Gradients are enabled
    only during the train phase; the scheduler steps once per training epoch.

    Args:
        model: network to fine-tune (expected already moved to ``device``).
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of epochs to train.

    Returns:
        The model with the state_dict that achieved the highest validation
        accuracy restored.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # enable dropout / batch-norm running-stat updates
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only while training; val runs inference-only.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Multiply by batch size so epoch_loss is a true per-sample mean
                # even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # Restore the best-validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model
# Train the model and save the weights.
# Streamlit re-executes this entire script on every UI interaction, so guard
# the expensive 25-epoch training run: only train when no saved checkpoint
# exists yet. Delete 'model_ft.pth' to force retraining.
if not os.path.exists('model_ft.pth'):
    model_ft = train_model(model_ft, criterion, optimizer_ft, scheduler, num_epochs=25)
    # Save the trained model
    torch.save(model_ft.state_dict(), 'model_ft.pth')
# Streamlit Interface
st.title("Image Classification with Fine-tuned ResNet")
# Accept the common image formats, not just .jpg.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Convert to RGB so RGBA/grayscale uploads don't break the 3-channel model.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")
    # Rebuild the architecture and load the fine-tuned weights. No ImageNet
    # download is needed here (pretrained=False) because every parameter is
    # overwritten by the checkpoint below.
    model_ft = models.resnet18(pretrained=False)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, len(class_names))
    # map_location ensures a checkpoint saved on GPU loads on CPU-only hosts.
    model_ft.load_state_dict(torch.load('model_ft.pth', map_location=device))
    model_ft = model_ft.to(device)
    model_ft.eval()
    # Same deterministic preprocessing as the validation pipeline above.
    preprocess = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(image).unsqueeze(0)   # add batch dimension: (1, 3, 224, 224)
    img = img.to(device)
    with torch.no_grad():
        outputs = model_ft(img)
        _, preds = torch.max(outputs, 1)
        predicted_class = class_names[preds[0]]
    st.write(f"Predicted Class: {predicted_class}")
    # Plotting the image with matplotlib
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"Predicted: {predicted_class}")
    st.pyplot(fig)