import torch
import torch.nn.functional as F
import lightning as L
import torchmetrics
import timm


class LightningWrapper(L.LightningModule):
    """Wrap a pretrained timm classifier in a LightningModule."""

    def __init__(self, timm_model: str, num_classes: int, learning_rate: float = 1e-3):
        super().__init__()
        self.timm_model = timm_model
        self.lr = learning_rate
        # Build the backbone with a classification head sized for this task.
        self.model = timm.create_model(
            self.timm_model, pretrained=True, num_classes=num_classes
        )
        # Record the init args (timm_model, num_classes, learning_rate) in hparams.
        self.save_hyperparameters()
        metrics = torchmetrics.MetricCollection(
            {
                "accuracy": torchmetrics.Accuracy(
                    task="multiclass", num_classes=self.model.num_classes
                )
            }
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def get_transforms(self, is_training: bool):
        # resolve_model_data_config expects the model instance (or its
        # pretrained config), not the model name string.
        data_config = timm.data.resolve_model_data_config(self.model)
        return timm.data.create_transform(**data_config, is_training=is_training)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        # Update the metric state, then log the collection so Lightning
        # computes and resets it at the appropriate boundary.
        self.train_metrics(logits, y)
        self.log_dict(self.train_metrics, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("val_loss", loss)
        self.val_metrics(logits, y)
        self.log_dict(self.val_metrics, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
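

# Minimal usage sketch. Everything below is illustrative, not part of the
# original app: the model name "resnet18" and the torchvision FakeData
# stand-in dataset are assumptions chosen only to show how the wrapper and
# its transforms plug into a Lightning Trainer.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision.datasets import FakeData

    module = LightningWrapper("resnet18", num_classes=10)

    # get_transforms resizes and normalizes PIL images to the backbone's
    # expected input config; any (image, label) dataset would work here.
    train_ds = FakeData(size=64, num_classes=10,
                        transform=module.get_transforms(is_training=True))
    val_ds = FakeData(size=16, num_classes=10,
                      transform=module.get_transforms(is_training=False))

    trainer = L.Trainer(fast_dev_run=True)  # one train/val batch as a smoke test
    trainer.fit(
        module,
        train_dataloaders=DataLoader(train_ds, batch_size=8, shuffle=True),
        val_dataloaders=DataLoader(val_ds, batch_size=8),
    )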