Spaces:
Runtime error
Runtime error
File size: 1,015 Bytes
63775f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import numpy as np
import torch
def batch_model_predict(model_predict, inputs, batch_size=32):
    """Run ``model_predict`` over ``inputs`` in batches and stack the results.

    Args:
        model_predict: Callable that receives one slice of ``inputs`` (a batch)
            and returns predictions for it: a ``str``, a ``torch.Tensor``, an
            ``np.ndarray``, or anything ``np.array`` can convert.
        inputs: Sequence supporting ``len()`` and slicing (e.g. a list).
            A plain iterator or generator is not accepted.
        batch_size: Maximum number of items passed to ``model_predict`` per
            call. Must be a positive integer.

    Returns:
        np.ndarray: The per-batch predictions concatenated along axis 0.
        An empty ``inputs`` yields an empty 1-D array.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # A zero or negative step would never walk past the end of ``inputs``.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    outputs = []
    for start in range(0, len(inputs), batch_size):
        batch_preds = model_predict(inputs[start : start + batch_size])

        # Some seq-to-seq models return a bare string as the prediction for a
        # single-string batch. Wrap it so it becomes a 1-element array.
        if isinstance(batch_preds, str):
            batch_preds = [batch_preds]

        # Drop any autograd history (np.array() refuses tensors that require
        # grad) and move the tensor off whatever device it lives on.
        if isinstance(batch_preds, torch.Tensor):
            batch_preds = batch_preds.detach().cpu()

        # Cast every remaining prediction container to ``np.ndarray``.
        if not isinstance(batch_preds, np.ndarray):
            batch_preds = np.array(batch_preds)
        outputs.append(batch_preds)

    if not outputs:
        # np.concatenate raises on an empty list; no inputs means no predictions.
        return np.array([])
    return np.concatenate(outputs, axis=0)
|