File size: 556 Bytes
eec04a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from typing import Any, List, Dict
import torch
from chronos import ChronosPipeline
class EndpointHandler:
    """Custom Inference Endpoint handler that serves a Chronos forecaster.

    A ``ChronosPipeline`` is loaded once at construction time. Each call then
    forecasts the series passed under ``data["inputs"]["context"]``.
    """

    # Forecast horizon used when the request does not supply
    # ``parameters.prediction_length`` (the value previously hard-coded).
    _DEFAULT_PREDICTION_LENGTH = 5

    def __init__(
        self, path: str = "", model_id: str = "amazon/chronos-t5-tiny"
    ) -> None:
        """Load the Chronos pipeline.

        Args:
            path: Repository directory. It is accepted because the endpoint
                toolkit passes it positionally when instantiating the handler.
                It is not used for loading; the model comes from ``model_id``.
            model_id: Identifier handed to ``ChronosPipeline.from_pretrained``.
        """
        self.pipeline = ChronosPipeline.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Forecast the series contained in a request payload.

        Args:
            data: Request payload. ``data["inputs"]["context"]`` holds the
                historical values, in any form ``torch.tensor`` accepts (e.g. a
                list of numbers, or a list of equal-length lists). The optional
                ``data["parameters"]["prediction_length"]`` overrides the
                forecast horizon, which defaults to 5. ``data`` is not modified.

        Returns:
            ``{"forecast": ...}``, where the value is the nested-list form of
            the tensor returned by ``ChronosPipeline.predict``.

        Raises:
            KeyError: If ``inputs`` or ``inputs["context"]`` is missing.
        """
        # Index rather than pop: same KeyError on a missing key, but the
        # caller's payload is left intact.
        inputs = data["inputs"]
        # Treat an explicit null/empty "parameters" the same as an absent one.
        parameters = data.get("parameters") or {}
        prediction_length = int(
            parameters.get("prediction_length", self._DEFAULT_PREDICTION_LENGTH)
        )
        # Explicit float32 so integer-valued contexts do not become int64.
        context = torch.tensor(inputs["context"], dtype=torch.float32)
        forecast = self.pipeline.predict(
            context, prediction_length=prediction_length
        )
        return {"forecast": forecast.tolist()}
|