import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from model.model import ExcitometerModel
from data.dataset import load_dataset
from data.preprocess import preprocess_data

# Training hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 10
num_classes = 10

# Model, device, loss, and optimizer setup
model = ExcitometerModel(num_classes=num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
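

# NOTE: torch.utils.data.Dataset is an abstract base class; it cannot be
# instantiated with (data, preprocess_fn) arguments as the original code did.
# The wrapper below is a minimal assumed sketch so the script runs end to end;
# the real project likely defines its own dataset class with this role, and
# both the class name and the (sample, label) layout of `data` are assumptions.
class ExcitometerDataset(Dataset):
    """Minimal map-style dataset that applies a preprocessing function per sample."""

    def __init__(self, data, preprocess_fn):
        # `data` is assumed to be a sequence of (sample, label) pairs.
        self.data = data
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, label = self.data[idx]
        return self.preprocess_fn(sample), label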


def load_data():
    """Build the training and validation DataLoaders."""
    train_data = load_dataset('train')
    val_data = load_dataset('val')

    # Wrap the raw splits so preprocess_data is applied to each sample.
    train_dataset = ExcitometerDataset(train_data, preprocess_data)
    val_dataset = ExcitometerDataset(val_data, preprocess_data)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


train_loader, val_loader = load_data()


def train_epoch():
    """Run one training pass over train_loader; return mean loss and accuracy."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate loss weighted by batch size and count correct predictions.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total

    return epoch_loss, epoch_acc


def validate_epoch():
    """Evaluate on val_loader without gradient tracking; return mean loss and accuracy."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = correct / total

    return epoch_loss, epoch_acc


# Main training loop: train and validate for num_epochs, then save the weights.
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch()
    val_loss, val_acc = validate_epoch()

    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}')

torch.save(model.state_dict(), 'excitometer_model.pth')

print('Training complete. Model saved as excitometer_model.pth')