|
import json |
|
import socket |
|
from typing import Optional |
|
|
|
from inference.core import logger |
|
from inference.enterprise.stream_management.manager.entities import ErrorType |
|
from inference.enterprise.stream_management.manager.errors import ( |
|
MalformedHeaderError, |
|
MalformedPayloadError, |
|
TransmissionChannelClosed, |
|
) |
|
from inference.enterprise.stream_management.manager.serialisation import ( |
|
prepare_error_response, |
|
) |
|
|
|
|
|
def receive_socket_data(
    source: socket.socket, header_size: int, buffer_size: int
) -> dict:
    """Read one length-prefixed JSON message from a socket.

    The message is expected to start with a ``header_size``-byte big-endian
    unsigned integer holding the payload length, followed by exactly that
    many bytes of JSON.

    Args:
        source: Socket to read from.
        header_size: Number of bytes occupied by the length header.
        buffer_size: Upper bound on the number of bytes requested per
            ``recv`` call while reading the payload.

    Returns:
        The value produced by ``json.loads`` on the payload (annotated as
        ``dict``; the decoded type is not validated here).

    Raises:
        MalformedHeaderError: If the first ``recv`` did not return exactly
            ``header_size`` bytes, or the decoded payload size is not positive.
        TransmissionChannelClosed: If ``recv`` returns no data before the
            whole payload has been read.
        MalformedPayloadError: If the payload cannot be decoded as JSON.
    """
    # NOTE(review): a single recv() may legally return fewer bytes than
    # requested on a stream socket; a short read is reported as a malformed
    # header rather than retried, exactly as before.
    header = source.recv(header_size)
    if len(header) != header_size:
        raise MalformedHeaderError(
            f"Expected header size: {header_size}, received: {header}"
        )
    payload_size = int.from_bytes(bytes=header, byteorder="big")
    if payload_size <= 0:
        raise MalformedHeaderError(
            f"Header is indicating non positive payload size: {payload_size}"
        )
    chunks = []
    remaining = payload_size
    while remaining > 0:
        # Never request more than what is still owed for this message, so
        # bytes belonging to the next message stay in the socket buffer.
        chunk = source.recv(min(buffer_size, remaining))
        if len(chunk) == 0:
            raise TransmissionChannelClosed(
                "Socket was closed to read before payload was decoded."
            )
        chunks.append(chunk)
        remaining -= len(chunk)
    try:
        # Join once at the end instead of repeatedly concatenating bytes.
        return json.loads(b"".join(chunks))
    except ValueError as error:
        raise MalformedPayloadError(
            "Received payload that is not in a JSON format"
        ) from error
|
|
|
|
|
def send_data_trough_socket(
    target: socket.socket,
    header_size: int,
    data: bytes,
    request_id: str,
    recover_from_overflow: bool = True,
    pipeline_id: Optional[str] = None,
) -> None:
    """Send ``data`` over a socket, prefixed with a big-endian length header.

    The frame is ``len(data)`` encoded on ``header_size`` bytes followed by
    ``data`` itself, delivered with ``sendall``. Failures are logged rather
    than propagated from the send attempt itself.

    Args:
        target: Socket the frame is written to.
        header_size: Number of bytes used to encode the payload length.
        data: Payload to send.
        request_id: Identifier forwarded to ``prepare_error_response`` when an
            overflow error response is built.
        recover_from_overflow: When true, an ``OverflowError`` (e.g. the
            length not fitting in ``header_size`` bytes) triggers a single
            attempt to send an error response instead; when false, the
            overflow is only logged.
        pipeline_id: Optional identifier forwarded to
            ``prepare_error_response``.
    """
    try:
        # Build the length prefix and ship header + payload in one call.
        size_prefix = len(data).to_bytes(length=header_size, byteorder="big")
        target.sendall(size_prefix + data)
    except OverflowError as error:
        if recover_from_overflow:
            # Presumably the serialised error is small enough to fit the
            # header; recovery is disabled on the retry so this cannot loop.
            fallback_payload = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INTERNAL_ERROR,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target=target,
                header_size=header_size,
                data=fallback_payload,
                request_id=request_id,
                recover_from_overflow=False,
                pipeline_id=pipeline_id,
            )
        else:
            logger.error(f"OverflowError was suppressed. {error}")
        return None
    except Exception as error:
        logger.error(f"Could not send the response through socket. Error: {error}")
|
|