from typing import Any, List, Dict | |
import torch | |
from chronos import ChronosPipeline | |
class EndpointHandler:
    """Inference Endpoint handler that serves forecasts from a Chronos pipeline."""

    # Hub id loaded when no other model is requested.
    DEFAULT_MODEL_ID = "amazon/chronos-t5-tiny"
    # Forecast horizon used when the request does not specify one.
    DEFAULT_PREDICTION_LENGTH = 5

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID) -> None:
        """Load the Chronos pipeline.

        Args:
            path: Model directory supplied by the hosting toolkit, which
                constructs handlers as ``EndpointHandler(path)``. Accepted so
                that call succeeds; it is not used. The weights are loaded from
                ``model_id`` instead.
            model_id: Identifier passed to ``ChronosPipeline.from_pretrained``.
        """
        self.pipeline = ChronosPipeline.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Forecast the series given in ``data["inputs"]["context"]``.

        Args:
            data: Request payload. ``data["inputs"]["context"]`` holds the
                historical values. ``data["parameters"]["prediction_length"]``
                is optional and overrides the default horizon of 5 steps.

        Returns:
            ``{"forecast": values}``, where ``values`` is the pipeline's
            forecast tensor converted with ``tolist()``.
            NOTE(review): per the Chronos docs this tensor is presumably
            shaped (num_series, num_samples, prediction_length); confirm.

        Raises:
            KeyError: If ``inputs`` or ``inputs["context"]`` is missing.
            ValueError: If ``prediction_length`` is not a positive integer.
        """
        # Index rather than pop, so the caller's payload dict is left intact.
        inputs = data["inputs"]
        # Treat an explicit None the same as an absent "parameters" key.
        parameters = data.get("parameters") or {}
        prediction_length = int(
            parameters.get("prediction_length", self.DEFAULT_PREDICTION_LENGTH)
        )
        if prediction_length <= 0:
            raise ValueError(
                f"prediction_length must be a positive integer, got {prediction_length}"
            )

        # Pin float32 so integer-valued series do not produce an int64 tensor.
        context = torch.tensor(inputs["context"], dtype=torch.float32)
        forecast = self.pipeline.predict(context, prediction_length=prediction_length)
        return {"forecast": forecast.tolist()}