# Provenance: "Update handler.py" by HazardSpence (commit f652866, verified).
from typing import List, Dict, Any
from logits import LogitsPredictor
class EndpointHandler:
    """Request handler that forwards inference payloads to a ``LogitsPredictor``.

    The handler is constructed once with a path and then called with one
    request dictionary per request.
    """

    def __init__(self, path=""):
        """Create the predictor and run its one-time setup.

        Args:
            path: Passed unchanged to ``LogitsPredictor.setup``. Presumably the
                directory holding the model assets -- confirm against
                ``logits.LogitsPredictor``.
        """
        self.predictor = LogitsPredictor()
        self.predictor.setup(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Unpack a request payload and return the predictor's result.

        Args:
            data: Request payload. ``data["inputs"]`` is the target text
                (default ``""``). ``data["parameters"]`` is an optional
                mapping (missing or ``None`` both mean "use all defaults")
                that may contain:

                - ``prefix_text`` (default ``""``)
                - ``context_length`` (default ``1024``)
                - ``stride`` (default ``512``)
                - ``topk`` (default ``-1``)
                - ``perf_metadata`` (default ``False``)

                ``data`` itself is not modified.

        Returns:
            Whatever ``self.predictor.predict`` returns for these arguments.
        """
        # Read instead of pop so the caller's payload dict is left intact.
        trg_text = data.get("inputs", "")
        # A JSON ``"parameters": null`` arrives as None; treat it as absent
        # instead of crashing on ``None.get``.
        parameters = data.get("parameters") or {}

        prefix_text = parameters.get("prefix_text", "")
        context_length = parameters.get("context_length", 1024)
        stride = parameters.get("stride", 512)
        topk = parameters.get("topk", -1)
        perf_metadata = parameters.get("perf_metadata", False)

        # Passed positionally, in the same order the predictor has always
        # received them.
        return self.predictor.predict(
            trg_text,
            prefix_text,
            context_length,
            stride,
            topk,
            perf_metadata,
        )