solara-esport-highlights / inference.py
lunde's picture
Initial commit
bd65e34
raw
history blame
1.74 kB
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
import polars as pl
import lightning as L
from data_utils.frame_dataset import FrameDataset
import torch
from models.lightning_wrapper import LightningWrapper
def run_inference(
    model_path: Path,
    image_folder: Path,
    aggregate_duration: int = 30,
    fps: int = 3,
) -> pl.DataFrame:
    """Classify extracted frames and sum the predictions per time window.

    Frames are discovered recursively as ``*.jpg`` files below
    *image_folder*. The frame index is the integer left in the file stem
    after stripping an ``img`` prefix (``img0042.jpg`` -> ``42``).

    Args:
        model_path: Checkpoint handed to
            ``LightningWrapper.load_from_checkpoint``.
        image_folder: Directory searched recursively for ``*.jpg`` frames.
        aggregate_duration: Length of one aggregation window. The window
            index is multiplied by this value and the result is treated as
            seconds.
        fps: Frames per second of the extracted frame sequence;
            ``aggregate_duration * fps`` frames make up one window.

    Returns:
        One row per window, sorted by ``timestamp``, with the columns
        ``frame`` (window start offset in seconds), ``preds`` (sum of the
        arg-max class indices of the frames in the window), ``hour``,
        ``minute``, ``second`` (the offset split into components) and
        ``timestamp`` (the offset added to the hard-coded date 2023-12-10).

    Raises:
        ValueError: If ``aggregate_duration`` or ``fps`` is not positive, or
            if a file stem is not ``img`` followed by an integer.
        FileNotFoundError: If no ``*.jpg`` file exists below *image_folder*.
    """
    # A zero/negative window size would make the floor-division below produce
    # nulls or nonsensical window indices instead of failing loudly.
    if aggregate_duration <= 0 or fps <= 0:
        raise ValueError(
            "aggregate_duration and fps must be positive, got "
            f"aggregate_duration={aggregate_duration}, fps={fps}"
        )

    # Look for frames before loading the checkpoint so that an empty folder
    # fails fast with a clear message rather than inside torch.cat later on.
    paths = list(image_folder.rglob("*.jpg"))
    if not paths:
        raise FileNotFoundError(f"No '*.jpg' frames found under {image_folder}")

    model = LightningWrapper.load_from_checkpoint(model_path)
    trainer = L.Trainer()

    df = pl.DataFrame(
        {"path": paths, "frame": [int(p.stem.removeprefix("img")) for p in paths]}
    ).sort("frame")

    # NOTE(review): the positional ``1`` is presumably the number of frames
    # per clip -- confirm against FrameDataset's signature.
    ds = FrameDataset(df, model.get_transforms(is_training=False), 1, is_train=False)
    dls = DataLoader(ds, batch_size=32, num_workers=2, pin_memory=True)

    preds_list: list[torch.Tensor] = trainer.predict(model, dataloaders=dls)  # type: ignore
    preds = torch.cat(preds_list)
    pred_class = torch.argmax(preds, dim=1)

    # One prediction is produced per clip; repeat it so that every frame of
    # the clip gets a value. ``.cpu()`` is a no-op for CPU tensors and keeps
    # ``.numpy()`` from failing should the predictions live on an accelerator.
    preds_class = np.repeat(pred_class.cpu().numpy(), ds.frames_per_clip)
    df = df.with_columns(preds=pl.Series(preds_class))

    # Bucket frames into windows of ``aggregate_duration * fps`` frames and
    # add up the predicted class indices inside each window.
    df_g = df.group_by(pl.col("frame") // (aggregate_duration * fps)).agg(
        pl.sum("preds")
    )

    seconds = pl.col("frame")
    df_g = (
        # Window index -> start offset of the window in seconds.
        df_g.with_columns(pl.col("frame") * aggregate_duration)
        .with_columns(
            hour=seconds // (60 * 60),
            minute=(seconds // 60) % 60,
            second=seconds % 60,
        )
        .with_columns(
            # Base date plus offset instead of hour/minute/second components:
            # identical for offsets below 24 h, and still valid (rolling over
            # to the next day) for longer recordings where ``hour`` >= 24.
            timestamp=pl.datetime(2023, 12, 10) + pl.duration(seconds=seconds)
        )
        .sort("timestamp")
    )
    return df_g