from pathlib import Path

import lightning as L
import numpy as np
import polars as pl
import torch
from torch.utils.data import DataLoader

from data_utils.frame_dataset import FrameDataset
from models.lightning_wrapper import LightningWrapper


def run_inference(
    model_path: Path,
    image_folder: Path,
    aggregate_duration: int = 30,
    fps: int = 3,
) -> pl.DataFrame:
    """Run a checkpointed model over a folder of frames and aggregate the results.

    Every ``*.jpg`` found recursively under ``image_folder`` is treated as one
    frame. The frame number is parsed from the file stem after stripping an
    ``img`` prefix. Frames are classified in batches, the arg-max class of each
    prediction is attached to its frame(s), and the class values are summed
    over consecutive windows of ``aggregate_duration * fps`` frames.

    Args:
        model_path: Checkpoint passed to ``LightningWrapper.load_from_checkpoint``.
        image_folder: Root folder searched recursively for ``*.jpg`` frames.
        aggregate_duration: Window length used for aggregation; multiplied by
            ``fps`` to get the number of frames per window (presumably seconds
            -- confirm against the frame extraction step).
        fps: Frames per unit of ``aggregate_duration``.

    Returns:
        One row per window, sorted by ``timestamp``, with columns ``frame``
        (window start offset, i.e. window index * ``aggregate_duration``),
        ``preds`` (sum of predicted class indices in the window), ``hour``,
        ``minute``, ``second`` and ``timestamp`` (the offset placed on the
        fixed date 2023-12-10).

    Raises:
        ValueError: If ``aggregate_duration`` or ``fps`` is not positive.
        FileNotFoundError: If no ``*.jpg`` file exists under ``image_folder``.
    """
    if aggregate_duration <= 0 or fps <= 0:
        raise ValueError(
            "aggregate_duration and fps must be positive, got "
            f"aggregate_duration={aggregate_duration}, fps={fps}"
        )

    paths = list(image_folder.rglob("*.jpg"))
    if not paths:
        # Fail early with a clear message instead of an opaque error from
        # torch.cat on an empty prediction list further down.
        raise FileNotFoundError(f"no .jpg frames found under {image_folder}")

    model = LightningWrapper.load_from_checkpoint(model_path)
    trainer = L.Trainer()

    # Sorting by frame number keeps the dataset order aligned with the
    # predictions that are attached back onto this frame table below.
    df = pl.DataFrame(
        {"path": paths, "frame": [int(p.stem.removeprefix("img")) for p in paths]}
    ).sort("frame")

    ds = FrameDataset(df, model.get_transforms(is_training=False), 1, is_train=False)
    dls = DataLoader(ds, batch_size=32, num_workers=2, pin_memory=True)

    preds_list: list[torch.Tensor] = trainer.predict(model, dataloaders=dls)  # type: ignore
    preds = torch.cat(preds_list)
    pred_class = torch.argmax(preds, dim=1)

    # Predictions may live on an accelerator; Tensor.numpy() only works on CPU
    # tensors, so move them to host memory first.
    # Each prediction is repeated once per frame of its clip so the result can
    # be attached row-by-row to the frame table.
    preds_class = np.repeat(pred_class.cpu().numpy(), ds.frames_per_clip)
    df = df.with_columns(preds=pl.Series(preds_class))

    # Bucket frames into fixed-size windows; the grouping key keeps the name
    # "frame" and holds the window index.
    frames_per_window = aggregate_duration * fps
    df_g = df.group_by(pl.col("frame") // frames_per_window).agg(pl.sum("preds"))

    # After the first with_columns, "frame" holds the window start offset.
    seconds = pl.col("frame")
    # NOTE(review): offsets of 24 hours or more do not fit a single calendar
    # day -- confirm how pl.datetime treats hour >= 24 before using this on
    # very long recordings.
    df_g = (
        df_g.with_columns(pl.col("frame") * aggregate_duration)
        .with_columns(
            hour=seconds // (60 * 60),
            minute=(seconds // 60) % 60,
            second=seconds % 60,
        )
        .with_columns(
            timestamp=pl.datetime(
                year=2023,
                month=12,
                day=10,
                hour=pl.col("hour"),
                minute=pl.col("minute"),
                second=pl.col("second"),
            )
        )
        .sort("timestamp")
    )
    return df_g