|
import numpy as np |
|
import torch |
|
|
|
|
|
def batch_model_predict(model_predict, inputs, batch_size=32):
    """Runs prediction on iterable ``inputs`` using batch size ``batch_size``.

    ``inputs`` is sliced into consecutive chunks of at most ``batch_size``
    items, ``model_predict`` is called once per chunk, and each result is
    converted to an ``np.ndarray``. Aggregates all predictions into a single
    ``np.ndarray`` by concatenating along axis 0.

    Args:
        model_predict: Callable that takes one slice of ``inputs`` and returns
            its predictions as a ``str``, a ``torch.Tensor``, an
            ``np.ndarray``, or anything ``np.array`` can convert.
        inputs: Sized, sliceable collection of model inputs.
        batch_size: Maximum number of inputs per call. Must be positive.

    Returns:
        An ``np.ndarray`` holding the predictions for every batch, in input
        order. An empty array is returned when ``inputs`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not a positive number.
    """
    if batch_size <= 0:
        # A non-positive step would never advance through ``inputs``.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    outputs = []
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start : start + batch_size]
        batch_preds = model_predict(batch)

        # A bare string is a single prediction; wrap it so it becomes a
        # one-element array instead of a 0-d array that cannot be concatenated.
        if isinstance(batch_preds, str):
            batch_preds = [batch_preds]

        # Detach from the autograd graph and move to host memory; NumPy cannot
        # convert tensors that require grad or that live on an accelerator.
        if isinstance(batch_preds, torch.Tensor):
            batch_preds = batch_preds.detach().cpu()

        if not isinstance(batch_preds, np.ndarray):
            batch_preds = np.array(batch_preds)
        outputs.append(batch_preds)

    if not outputs:
        # ``np.concatenate`` rejects an empty sequence of arrays.
        return np.array([])

    return np.concatenate(outputs, axis=0)
|
|