|
import os
from typing import Dict, List, Any

import chatglm_cpp
import torch
|
|
|
|
|
|
|
|
|
# Markers of the single-string chat transcript that EndpointHandler receives
# in the request's "inputs" field.
#
# TURN_BREAKER separates consecutive turns. The *_SYMBOL markers tag the role
# of a turn and are stripped from the text before it is handed to the model.
TURN_BREAKER = "<||turn_breaker||>"
SYSTEM_SYMBOL = "<||system_symbol||>"
USER_SYMBOL = "<||user_symbol||>"
ASSISTANT_SYMBOL = "<||assistant_symbol||>"
|
|
|
|
|
class EndpointHandler:
    """Callable request handler that runs chat inference through chatglm_cpp.

    The handler loads the ``q8_0_v2.bin`` model file once at construction and
    then answers requests whose ``"inputs"`` field is a single string holding
    the whole conversation: turns are separated by ``TURN_BREAKER`` and each
    turn carries a role marker (``SYSTEM_SYMBOL``, ``USER_SYMBOL`` or
    ``ASSISTANT_SYMBOL``).
    """

    def __init__(self, path=""):
        """Load the model pipeline from ``q8_0_v2.bin`` inside ``path``.

        Args:
            path: Directory containing the model file. With the default empty
                string the file is looked up relative to the current working
                directory (``os.path.join`` avoids turning ``""`` into the
                filesystem root ``/q8_0_v2.bin``).
        """
        self.pipeline = chatglm_cpp.Pipeline(os.path.join(path, "q8_0_v2.bin"))

    def __call__(self, data: Any) -> str:
        """Run one chat completion for the conversation encoded in ``data``.

        Args:
            data: Request payload (a dict). ``data["inputs"]`` is the
                serialized conversation string; the optional
                ``data["parameters"]`` is a dict of keyword arguments that is
                forwarded unchanged to ``self.pipeline.chat``.

        Returns:
            The ``content`` attribute of the message returned by
            ``self.pipeline.chat``.

        Raises:
            TypeError: If ``"inputs"`` is missing or is not a string.
        """
        # Read with .get() so the caller's payload dict is left untouched.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", None)

        if not isinstance(inputs, str):
            raise TypeError(
                '"inputs" must be a string of chat turns separated by '
                f"{TURN_BREAKER!r}, got {type(inputs).__name__}"
            )

        segments = inputs.split(TURN_BREAKER)

        # The first segment is always treated as the system prompt.
        messages = [
            chatglm_cpp.ChatMessage(
                role="system",
                content=segments[0].replace(SYSTEM_SYMBOL, ""),
            )
        ]

        # Every later segment is a user turn when it carries USER_SYMBOL and
        # an assistant turn otherwise; only that role's marker is stripped.
        for segment in segments[1:]:
            if USER_SYMBOL in segment:
                role, marker = "user", USER_SYMBOL
            else:
                role, marker = "assistant", ASSISTANT_SYMBOL
            messages.append(
                chatglm_cpp.ChatMessage(
                    role=role,
                    content=segment.replace(marker, ""),
                )
            )

        # No parameters supplied means calling chat() with its own defaults.
        chat_kwargs = parameters if parameters is not None else {}
        prediction = self.pipeline.chat(messages, **chat_kwargs)

        return prediction.content