|
from typing import List, Dict, Any |
|
from logits import LogitsPredictor |
|
|
|
class EndpointHandler:
    """Request handler that forwards a payload to a ``LogitsPredictor``.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape looks like
    a Hugging Face Inference Endpoints custom handler -- confirm against the
    deployment.
    """

    def __init__(self, path: str = ""):
        """Create the predictor and prepare it via ``setup(path)``.

        Args:
            path: Forwarded unchanged to ``LogitsPredictor.setup``; presumably
                the location of the model files (TODO confirm).
        """
        self.predictor = LogitsPredictor()
        self.predictor.setup(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Unpack a request payload and run the predictor on it.

        Args:
            data: Request payload. ``data["inputs"]`` is the target text
                (default ``""``). ``data["parameters"]`` is an optional dict
                with keys ``prefix_text`` (default ``""``), ``context_length``
                (default 1024), ``stride`` (default 512), ``topk`` (default -1)
                and ``perf_metadata`` (default False). If ``inputs`` or
                ``parameters`` is present with value ``None``, it is treated
                as missing. ``data`` itself is not modified.

        Returns:
            Whatever ``self.predictor.predict`` returns (annotated as a dict).
        """
        # Read rather than pop: popping mutated the caller's payload, so
        # handling the same dict twice silently dropped the text.
        trg_text = data.get("inputs")
        if trg_text is None:
            trg_text = ""

        # JSON clients may send "parameters": null; treat it as "no overrides"
        # instead of failing with AttributeError on None.get(...).
        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}

        prefix_text = parameters.get("prefix_text", "")
        context_length = parameters.get("context_length", 1024)
        stride = parameters.get("stride", 512)
        topk = parameters.get("topk", -1)
        perf_metadata = parameters.get("perf_metadata", False)

        # Arguments are positional, so this order must match
        # LogitsPredictor.predict's signature.
        return self.predictor.predict(
            trg_text,
            prefix_text,
            context_length,
            stride,
            topk,
            perf_metadata,
        )