dolly-v2-12b-endpoint / handler.py
dlowl's picture
Load dolly-v2 model with remote code trusted and full text returned (so it's usable with langchain)
0fabff5
raw
history blame contribute delete
892 Bytes
import torch
from typing import Dict, Any, List
from transformers import pipeline
class EndpointHandler:
    """Custom Inference Endpoints handler wrapping a transformers pipeline.

    Loads the model at *path* once at startup and serves generation
    requests through ``__call__``.
    """

    def __init__(
        self,
        path: str,
    ) -> None:
        # trust_remote_code is required because dolly-v2 ships its own
        # pipeline implementation; return_full_text keeps the prompt in
        # the generated output so the result is usable with langchain.
        self.pipeline = pipeline(
            model=path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            return_full_text=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Fall back to the payload itself when no "inputs" key is present
        # (standard HF handler-template behavior).
        prompt = data.pop("inputs", data)
        extra = data.pop("parameters", None)
        # Forward any user-supplied generation parameters as keyword args;
        # ``extra or {}`` expands to nothing for both None and {}.
        return self.pipeline(prompt, **(extra or {}))