File size: 556 Bytes
eec04a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from typing import Any, List, Dict
import torch

from chronos import ChronosPipeline


class EndpointHandler:
    """Serve Chronos time-series forecasts behind an inference endpoint.

    A ``ChronosPipeline`` is loaded once at construction time. Each call
    forecasts ``prediction_length`` steps beyond the context supplied in the
    request payload.
    """

    # Horizon used when the request does not supply
    # ``parameters["prediction_length"]``; this is the value the handler
    # previously hard-coded, so existing callers see unchanged output.
    DEFAULT_PREDICTION_LENGTH = 5

    def __init__(
        self, path: str = "", model_id: str = "amazon/chronos-t5-tiny"
    ) -> None:
        """Load the forecasting pipeline.

        Args:
            path: Model directory. Accepted so the handler can be constructed
                as ``EndpointHandler(path)``, which is presumably how the
                hosting toolkit instantiates custom handlers (NOTE(review):
                confirm). It is not used; weights come from ``model_id``.
            model_id: Identifier passed to ``ChronosPipeline.from_pretrained``.
        """
        self.pipeline = ChronosPipeline.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Forecast future values for the request's context series.

        Args:
            data: Request payload. ``data["inputs"]["context"]`` holds the
                historical values. An optional ``data["parameters"]`` mapping
                may set ``prediction_length`` (defaults to
                ``DEFAULT_PREDICTION_LENGTH``). The payload is read, never
                modified.

        Returns:
            ``{"forecast": ...}`` where the value is the tensor returned by
            ``pipeline.predict`` converted to nested Python lists
            (presumably shaped (series, samples, prediction_length) per the
            Chronos docs; TODO confirm).

        Raises:
            KeyError: If ``"inputs"`` or ``inputs["context"]`` is missing.
        """
        inputs = data["inputs"]
        # ``or {}`` also covers an explicit ``"parameters": null`` in JSON.
        parameters = data.get("parameters") or {}
        prediction_length = int(
            parameters.get("prediction_length", self.DEFAULT_PREDICTION_LENGTH)
        )

        # Explicit float32 so integer-valued histories don't yield int64.
        context = torch.tensor(inputs["context"], dtype=torch.float32)
        forecast = self.pipeline.predict(
            context, prediction_length=prediction_length
        )
        return {"forecast": forecast.tolist()}