|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Service-side implementation of gRPC Python.""" |
|
|
|
from __future__ import annotations |
|
|
|
import collections |
|
from concurrent import futures |
|
import enum |
|
import logging |
|
import threading |
|
import time |
|
import traceback |
|
from typing import (Any, Callable, Iterable, Iterator, List, Mapping, Optional, |
|
Sequence, Set, Tuple, Union) |
|
|
|
import grpc |
|
from grpc import _common |
|
from grpc import _compression |
|
from grpc import _interceptor |
|
from grpc._cython import cygrpc |
|
from grpc._typing import ArityAgnosticMethodHandler |
|
from grpc._typing import ChannelArgumentType |
|
from grpc._typing import DeserializingFunction |
|
from grpc._typing import MetadataType |
|
from grpc._typing import NullaryCallbackType |
|
from grpc._typing import ResponseType |
|
from grpc._typing import SerializingFunction |
|
from grpc._typing import ServerCallbackTag |
|
from grpc._typing import ServerTagCallbackType |
|
|
|
# Logger shared by everything in this module.
_LOGGER = logging.getLogger(__name__)

# Server-level completion-queue tags (not tied to a single RPC). The serving
# loop distinguishes them from per-RPC tags by identity.
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'

# Per-RPC batch tokens. A token is placed in _RPCState.due when its batch is
# started and removed when the batch's completion callback runs.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
# Combined batches are named "<first> * <second>".
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ' * '.join(
    (_SEND_INITIAL_METADATA_TOKEN, _SEND_MESSAGE_TOKEN))
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ' * '.join(
    (_SEND_INITIAL_METADATA_TOKEN, _SEND_STATUS_FROM_SERVER_TOKEN))

# Values of _RPCState.client (compared by identity).
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'

# "No flags" value passed to cygrpc operations.
_EMPTY_FLAGS = 0

# Seconds the serving loop waits per poll before re-checking whether the
# server object has been deallocated.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
# Very large timeout value, presumably meaning "no timeout" (seconds?);
# NOTE(review): confirm at its use site.
_INF_TIMEOUT = 1e9
|
|
|
|
|
def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
    """Return the raw payload carried by a receive-message event.

    The payload is whatever ``message()`` of the event's first batch
    operation returns; callers in this module treat ``None`` as "the client
    sent no further message".
    """
    first_operation = request_event.batch_operations[0]
    return first_operation.message()
|
|
|
|
|
def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate an application-facing status code into its cygrpc value.

    Codes missing from the mapping fall back to ``cygrpc.StatusCode.unknown``.
    """
    mapped = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if mapped is None:
        return cygrpc.StatusCode.unknown
    return mapped
|
|
|
|
|
def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Return the status code for a normally completed RPC.

    This is ``ok`` unless the handler set an explicit code on ``state``, in
    which case that code is translated to its cygrpc value.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
|
|
|
|
|
def _abortion_code(state: _RPCState,
                   code: cygrpc.StatusCode) -> cygrpc.StatusCode:
    """Return the status code to use when the server aborts an RPC.

    A code already set by the handler on ``state`` wins over the
    server-chosen ``code``.
    """
    return code if state.code is None else _application_code(state.code)
|
|
|
|
|
def _details(state: _RPCState) -> bytes:
    """Return the status details recorded on ``state``, or ``b''`` if unset."""
    recorded = state.details
    if recorded is None:
        return b''
    return recorded
|
|
|
|
|
class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails',
                               ['method', 'invocation_metadata']),
        grpc.HandlerCallDetails):
    """Immutable ``(method, invocation_metadata)`` record.

    Instances are handed to generic handlers (and interceptors, when
    configured) to look up the method handler for an incoming call.
    """
|
|
|
|
|
class _RPCState(object):
    """Mutable bookkeeping for a single RPC.

    The state is shared between the thread running the application handler
    and the completion-queue callbacks for the RPC's batches. Most reads and
    writes happen while holding ``condition``.
    """

    # Lock for the fields below. Completion callbacks call notify_all() on it
    # when the client closes or cancels, a message arrives, or a message has
    # been sent.
    condition: threading.Condition
    # Tokens of batches that have been started but not yet completed.
    # (Previously written as ``due = Set[str]``, which assigned a typing
    # object as a class attribute instead of declaring an annotation.)
    due: Set[str]
    # Deserialized request delivered by the latest receive-message batch and
    # not yet consumed by the handler.
    request: Any
    # One of _OPEN, _CLOSED or _CANCELLED.
    client: str
    # True until initial metadata has been sent (alone or bundled).
    initial_metadata_allowed: bool
    # Compression algorithm chosen via the servicer context, if any.
    compression_algorithm: Optional[grpc.Compression]
    # When True, the next outgoing message is sent uncompressed.
    disable_next_compression: bool
    # Trailing metadata to send with the final status.
    trailing_metadata: Optional[MetadataType]
    # Status code and details set by the handler; None when unset.
    code: Optional[grpc.StatusCode]
    details: Optional[bytes]
    # True once a batch carrying the final status has been started.
    statused: bool
    # RpcErrors the server raised into application code, so they are not
    # reported again as application failures.
    rpc_errors: List[Exception]
    # Callbacks to run when the RPC finishes; None once they were handed off.
    callbacks: Optional[List[NullaryCallbackType]]
    # True once the handler called abort() on its context.
    aborted: bool

    def __init__(self):
        self.condition = threading.Condition()
        self.due = set()
        self.request = None
        self.client = _OPEN
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.trailing_metadata = None
        self.code = None
        self.details = None
        self.statused = False
        self.rpc_errors = []
        self.callbacks = []
        self.aborted = False
|
|
|
|
|
def _raise_rpc_error(state: _RPCState) -> None:
    """Raise a fresh ``grpc.RpcError`` after remembering it on ``state``.

    Remembering the error lets the handler-invocation code recognise it as
    raised by the server itself rather than by the application.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
|
|
|
|
|
def _possibly_finish_call(state: _RPCState,
                          token: str) -> ServerTagCallbackType:
    """Record that the batch identified by ``token`` has completed.

    Every caller in this module holds ``state.condition``.

    Returns:
      ``(state, callbacks)`` if the RPC is no longer active and no batches
      remain outstanding. ``state.callbacks`` is then cleared so no further
      callbacks can be registered. Otherwise ``(None, ())``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    pending_callbacks, state.callbacks = state.callbacks, None
    return state, pending_callbacks
|
|
|
|
|
def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion tag for a batch that sends the RPC's status.

    ``token`` identifies the batch in ``state.due``.
    """

    def on_status_sent(unused_send_status_from_server_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return on_status_sent
|
|
|
|
|
def _get_initial_metadata(
        state: _RPCState,
        metadata: Optional[MetadataType]) -> Optional[MetadataType]:
    """Return the initial metadata to send for the RPC.

    When a compression algorithm has been set on ``state``, its metadata
    entry is placed in front of ``metadata``. Otherwise ``metadata`` is
    returned unchanged.
    """
    with state.condition:
        if not state.compression_algorithm:
            return metadata
        compression_entry = _compression.compression_algorithm_to_metadata(
            state.compression_algorithm)
        remainder = () if metadata is None else tuple(metadata)
        return (compression_entry,) + remainder
|
|
|
|
|
def _get_initial_metadata_operation(
        state: _RPCState, metadata: Optional[MetadataType]) -> cygrpc.Operation:
    """Build a send-initial-metadata operation for the RPC.

    Any compression metadata configured on ``state`` is included.
    """
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS)
|
|
|
|
|
def _abort(state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode,
           details: bytes) -> None:
    """Start a batch that ends the RPC with a server-chosen failure status.

    Every caller in this module holds ``state.condition``. Does nothing if
    the client already cancelled. A code or details value previously set by
    the handler takes precedence over ``code``/``details``. If initial
    metadata has not been sent yet, it is bundled into the same batch.
    """
    if state.client is _CANCELLED:
        return
    status_code = _abortion_code(state, code)
    status_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
        operations = (
            _get_initial_metadata_operation(state, None),
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 status_code, status_details,
                                                 _EMPTY_FLAGS),
        )
    else:
        token = _SEND_STATUS_FROM_SERVER_TOKEN
        operations = (cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata, status_code, status_details,
            _EMPTY_FLAGS),)
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
|
|
|
|
|
def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
    """Build the completion tag for the receive-close-on-server batch.

    When the batch completes, the client is marked cancelled if the batch
    reports cancellation, or closed if it was still open. Threads waiting on
    ``state.condition`` are then woken.
    """

    def on_close(receive_close_on_server_event):
        with state.condition:
            close_operation = receive_close_on_server_event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return on_close
|
|
|
|
|
def _receive_message(
    state: _RPCState, call: cygrpc.Call,
    request_deserializer: Optional[DeserializingFunction]
) -> ServerCallbackTag:
    """Build the completion tag for a receive-message batch.

    A ``None`` payload means no message was received, and an open client is
    marked closed. Otherwise the payload is deserialized (outside the lock)
    and stored in ``state.request``. If deserialization yields ``None``, the
    RPC is aborted with INTERNAL instead. Waiters are notified in every case.
    """

    def on_message(receive_message_event):
        payload = _serialized_request(receive_message_event)
        if payload is None:
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

        request = _common.deserialize(payload, request_deserializer)
        with state.condition:
            if request is None:
                _abort(state, call, cygrpc.StatusCode.internal,
                       b'Exception deserializing request!')
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return on_message
|
|
|
|
|
def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
    """Build the completion tag for a standalone send-initial-metadata batch."""

    def on_initial_metadata_sent(unused_send_initial_metadata_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return on_initial_metadata_sent
|
|
|
|
|
def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion tag for a send-message batch named ``token``.

    Completion wakes any thread waiting on ``state.condition`` for the send
    to finish.
    """

    def on_message_sent(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return on_message_sent
|
|
|
|
|
class _Context(grpc.ServicerContext):
    """Servicer context handed to application handlers on the server.

    Call-level information (metadata, peer, deadline, auth) is read from the
    request-call event. Handler decisions (status code, details, metadata,
    compression) are recorded on the shared ``_RPCState``, where the code
    that later sends the status picks them up.
    """

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(self, rpc_event: cygrpc.BaseEvent, state: _RPCState,
                 request_deserializer: Optional[DeserializingFunction]):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        """Return True unless the client cancelled or a status was sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        """Return the time left before the call deadline, never negative.

        The deadline is compared against ``time.time()``.
        """
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        """Cancel the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """Register ``callback`` to run once the RPC has finished.

        Returns:
          True if the callback was registered. False if the RPC already
          finished and its callback list has been handed off.
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            self._state.callbacks.append(callback)
            return True

    def disable_next_message_compression(self) -> None:
        """Send the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        """Return the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        """Return the peer string reported by the call, decoded to ``str``."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        """Return the peer identities cygrpc reports for the call."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        """Return the peer identity key as ``str``, or None if absent."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        """Return the call's auth context with keys decoded to ``str``.

        An empty mapping is returned when cygrpc reports no auth context.
        """
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        """Choose the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send initial metadata now, ahead of any response message.

        Raises:
          grpc.RpcError: if the client has cancelled.
          ValueError: if initial metadata has already been sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata)
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state))
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError('Initial metadata no longer allowed!')

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        """Set the trailing metadata sent with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the trailing metadata set so far, if any."""
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record a failure status and raise to unwind the handler.

        ``StatusCode.OK`` is not a valid abort code. It is logged and
        replaced by ``UNKNOWN`` with empty details.

        Raises:
          Exception: always. The handler-invocation code recognises the
            abort through ``state.aborted``.
        """
        # Treat OK like any other invalid argument: fail the RPC.
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code = grpc.StatusCode.UNKNOWN
            details = ''
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Abort using the code, details and trailing metadata of ``status``.

        Raises:
          Exception: always (see ``abort``).
        """
        # Hold the lock for the write, as set_trailing_metadata() does, so it
        # is serialized with code that reads trailing_metadata under the lock.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        """Set the status code sent when the RPC completes."""
        with self._state.condition:
            self._state.code = code

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the status code set so far, or None if unset."""
        return self._state.code

    def set_details(self, details: str) -> None:
        """Set the status details (encoded to bytes) sent on completion."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> Optional[bytes]:
        """Return the encoded status details set so far, or None if unset."""
        return self._state.details

    def _finalize_state(self) -> None:
        """No-op; this context has nothing to finalize."""
        pass
|
|
|
|
|
class _RequestIterator(object):
    """Blocking iterator over the requests of a client-streaming RPC.

    Each ``next()`` starts a single receive-message batch, then waits on the
    RPC's condition. It returns when a request arrives, raises
    ``StopIteration`` when the stream ends, and raises ``grpc.RpcError`` if
    the client cancelled.
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(self, state: _RPCState, call: cygrpc.Call,
                 request_deserializer: Optional[DeserializingFunction]):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        """Start one receive-message batch, or raise if reading must stop.

        Called with the state's condition held.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif _is_rpc_state_active(state):
            on_message = _receive_message(state, self._call,
                                          self._request_deserializer)
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), on_message)
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
        else:
            raise StopIteration()

    def _look_for_request(self) -> Any:
        """Take the delivered request, if any.

        Called with the state's condition held. Returns the request (clearing
        it from the state), or None while the receive batch is still pending.
        Raises ``StopIteration`` once the batch completed without a request.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        elif state.request is not None or _RECEIVE_MESSAGE_TOKEN in state.due:
            request, state.request = state.request, None
            return request
        else:
            raise StopIteration()
        # Unreachable: every branch above returns or raises.
        raise AssertionError()

    def _next(self) -> Any:
        condition = self._state.condition
        with condition:
            self._raise_or_start_receive_message()
            request = None
            # The receive callback needs the condition, so it can only run
            # while wait() has released it. Re-check after every wake-up.
            while request is None:
                condition.wait()
                request = self._look_for_request()
            return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    def next(self) -> Any:
        return self._next()
|
|
|
|
|
def _unary_request(
    rpc_event: cygrpc.BaseEvent, state: _RPCState,
    request_deserializer: Optional[DeserializingFunction]
) -> Callable[[], Any]:
    """Build a thunk that receives the single request of a unary-request RPC.

    Args:
      rpc_event: The request-call event for the RPC.
      state: The RPC's shared state.
      request_deserializer: Deserializer applied to the received payload.

    Returns:
      A zero-argument callable. It starts a receive-message batch and blocks
      until one request has arrived, then returns that request. It returns
      None instead if:
        * the RPC was already inactive;
        * the client cancelled;
        * the client is closed with no request delivered. In this case the
          RPC is first aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            else:
                rpc_event.call.start_server_batch(
                    (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                    _receive_message(state, rpc_event.call,
                                     request_deserializer))
                state.due.add(_RECEIVE_MESSAGE_TOKEN)
                # The receive callback needs state.condition, so it can only
                # run while wait() has released the lock. Re-check after
                # every wake-up.
                while True:
                    state.condition.wait()
                    if state.request is None:
                        if state.client is _CLOSED:
                            details = '"{}" requires exactly one request message.'.format(
                                rpc_event.call_details.method)
                            _abort(state, rpc_event.call,
                                   cygrpc.StatusCode.unimplemented,
                                   _common.encode(details))
                            return None
                        elif state.client is _CANCELLED:
                            return None
                        # Client still open and nothing delivered yet: wait again.
                    else:
                        request = state.request
                        state.request = None
                        return request

    return unary_request
|
|
|
|
|
def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application handler inside a fresh servicer context.

    The handler is called as ``behavior(argument, context)``. When
    ``send_response_callback`` is given, it is passed as a third argument.

    Returns:
      ``(result, True)`` when the handler returns normally. If it raises,
      ``(None, False)`` is returned. The RPC is aborted before returning if
      the handler called ``abort()``, or if the exception is not an RpcError
      the server itself raised.
    """
    from grpc import _create_servicer_context
    with _create_servicer_context(rpc_event, state,
                                  request_deserializer) as context:
        handler_args = (argument, context)
        if send_response_callback is not None:
            handler_args += (send_response_callback,)
        try:
            return behavior(*handler_args), True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # Code/details from abort() override these defaults.
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           b'RPC Aborted')
                elif exception not in state.rpc_errors:
                    # str(exception) may itself raise; fall back to a fixed
                    # message in that case.
                    try:
                        details = 'Exception calling application: {}'.format(
                            exception)
                    except Exception:  # pylint: disable=broad-except
                        details = 'Calling application raised unprintable Exception!'
                        traceback.print_exc()
                    _LOGGER.exception(details)
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           _common.encode(details))
            return None, False
|
|
|
|
|
def _take_response_from_response_iterator(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        response_iterator: Iterator[ResponseType]) -> Tuple[ResponseType, bool]:
    """Pull the next response from an application's response iterator.

    Returns:
      ``(response, True)`` for a produced response.
      ``(None, True)`` once the iterator is exhausted.
      ``(None, False)`` if the iterator raised. In that case the RPC has been
      aborted, unless the exception was an RpcError the server itself raised.
    """
    try:
        return next(response_iterator), True
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                # Code/details from abort() override these defaults.
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                # str(exception) can itself raise. Without this guard, that
                # error would escape before _abort() runs and leave the RPC
                # without a status. Mirrors the guard in _call_behavior.
                try:
                    details = 'Exception iterating responses: {}'.format(
                        exception)
                except Exception:  # pylint: disable=broad-except
                    details = 'Iterating responses raised unprintable Exception!'
                    traceback.print_exc()
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False
|
|
|
|
|
def _serialize_response(
        rpc_event: cygrpc.BaseEvent, state: _RPCState, response: Any,
        response_serializer: Optional[SerializingFunction]) -> Optional[bytes]:
    """Serialize ``response`` for sending.

    Returns the serialized bytes. If serialization yields None, the RPC is
    aborted with INTERNAL and None is returned.
    """
    serialized = _common.serialize(response, response_serializer)
    if serialized is not None:
        return serialized
    with state.condition:
        _abort(state, rpc_event.call, cygrpc.StatusCode.internal,
               b'Failed to serialize response!')
        return None
|
|
|
|
|
def _get_send_message_op_flags_from_state(
        state: _RPCState) -> Union[int, cygrpc.WriteFlag]:
    """Return the write flags for the next outgoing message.

    ``no_compress`` is returned if compression was disabled for it; otherwise
    no flags.
    """
    return (cygrpc.WriteFlag.no_compress
            if state.disable_next_compression else _EMPTY_FLAGS)
|
|
|
|
|
def _reset_per_message_state(state: _RPCState) -> None:
    """Clear settings that applied only to the message just sent."""
    condition = state.condition
    with condition:
        state.disable_next_compression = False
|
|
|
|
|
def _send_response(rpc_event: cygrpc.BaseEvent, state: _RPCState,
                   serialized_response: bytes) -> bool:
    """Send one response message and block until the send batch completes.

    If initial metadata has not been sent yet, it is bundled into the same
    batch. Per-message settings (e.g. disabled compression) are reset once
    the batch has been started.

    Args:
      rpc_event: The request-call event for the RPC.
      state: The RPC's shared state.
      serialized_response: The already-serialized response message.

    Returns:
      False, without sending, if the RPC is no longer active. Otherwise,
      whether the RPC is still active after the send batch has completed.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state)),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state)),)
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(operations,
                                              _send_message(state, token))
            state.due.add(token)
            _reset_per_message_state(state)
            # The send callback removes ``token`` from ``state.due`` and
            # notifies under this same condition. Wait until that happens.
            while True:
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)
|
|
|
|
|
def _status(rpc_event: cygrpc.BaseEvent, state: _RPCState,
            serialized_response: Optional[bytes]) -> None:
    """Finish the RPC by starting a batch that sends its final status.

    The same batch also carries initial metadata if it has not been sent
    yet, and ``serialized_response`` when one is given. Does nothing if the
    client cancelled.
    """
    with state.condition:
        if state.client is _CANCELLED:
            return
        code = _completion_code(state)
        details = _details(state)
        batch = [
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 code, details, _EMPTY_FLAGS),
        ]
        if state.initial_metadata_allowed:
            batch.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            write_flags = _get_send_message_op_flags_from_state(state)
            batch.append(
                cygrpc.SendMessageOperation(serialized_response, write_flags))
        on_status_sent = _send_status_from_server(
            state, _SEND_STATUS_FROM_SERVER_TOKEN)
        rpc_event.call.start_server_batch(batch, on_status_sent)
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
|
|
|
|
|
def _unary_response_in_pool(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any],
        request_deserializer: Optional[DeserializingFunction],
        response_serializer: Optional[SerializingFunction]) -> None:
    """Run a unary-response handler on a worker thread and send its result.

    Args:
      rpc_event: The request-call event for the RPC.
      state: The RPC's shared state.
      behavior: The application handler.
      argument_thunk: Produces the handler's argument (the single request, or
        a request iterator). A None result skips the handler entirely.
      request_deserializer: Deserializer for incoming requests (previously
        mis-annotated as a serializing function).
      response_serializer: Serializer for the handler's response.

    On success, the serialized response is sent in the same batch as the
    final status. Unexpected errors are printed rather than propagated. The
    context installed from the request-call event is always uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)
    try:
        argument = argument_thunk()
        if argument is not None:
            response, proceed = _call_behavior(rpc_event, state, behavior,
                                               argument, request_deserializer)
            if proceed:
                serialized_response = _serialize_response(
                    rpc_event, state, response, response_serializer)
                if serialized_response is not None:
                    _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
|
|
|
|
|
def _stream_response_in_pool(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any],
        request_deserializer: Optional[DeserializingFunction],
        response_serializer: Optional[SerializingFunction]) -> None:
    """Run a streaming-response handler on a worker thread.

    Handlers flagged ``experimental_non_blocking`` receive a send-response
    callback and push responses themselves. Other handlers return an
    iterator, which is drained here. In both cases a None response ends the
    RPC with its final status. Unexpected errors are printed rather than
    propagated, and the context installed from the request-call event is
    always uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        # None marks the end of the response stream.
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(rpc_event, state, response,
                                                  response_serializer)
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        if getattr(behavior, 'experimental_non_blocking', False):
            _call_behavior(rpc_event,
                           state,
                           behavior,
                           argument,
                           request_deserializer,
                           send_response_callback=send_response)
            return
        response_iterator, proceed = _call_behavior(rpc_event, state,
                                                    behavior, argument,
                                                    request_deserializer)
        if proceed:
            _send_message_callback_to_blocking_iterator_adapter(
                rpc_event, state, send_response, response_iterator)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
|
|
|
|
|
def _is_rpc_state_active(state: _RPCState) -> bool:
    """Return True while the client has not cancelled and no status was sent."""
    return not (state.client is _CANCELLED or state.statused)
|
|
|
|
|
def _send_message_callback_to_blocking_iterator_adapter(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        send_response_callback: Callable[[ResponseType], None],
        response_iterator: Iterator[ResponseType]) -> None:
    """Feed each response from ``response_iterator`` to the send callback.

    Stops when the iterator fails or when the RPC becomes inactive. The
    exhaustion marker (None) is also passed to the callback, which then
    sends the status.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator)
        if not proceed:
            return
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            return
|
|
|
|
|
def _select_thread_pool_for_behavior(
    behavior: ArityAgnosticMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor
) -> futures.ThreadPoolExecutor:
    """Pick the executor that should run ``behavior``.

    A ``ThreadPoolExecutor`` attached to the handler as
    ``experimental_thread_pool`` wins. Anything else falls back to
    ``default_thread_pool``.
    """
    handler_pool = getattr(behavior, 'experimental_thread_pool', None)
    if isinstance(handler_pool, futures.ThreadPoolExecutor):
        return handler_pool
    return default_thread_pool
|
|
|
|
|
def _handle_unary_unary(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        method_handler: grpc.RpcMethodHandler,
        default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future:
    """Schedule a unary-request/unary-response RPC on its executor."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.unary_unary
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       argument_thunk, request_deserializer,
                       method_handler.response_serializer)
|
|
|
|
|
def _handle_unary_stream(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        method_handler: grpc.RpcMethodHandler,
        default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future:
    """Schedule a unary-request/streaming-response RPC on its executor."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.unary_stream
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       argument_thunk, request_deserializer,
                       method_handler.response_serializer)
|
|
|
|
|
def _handle_stream_unary(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        method_handler: grpc.RpcMethodHandler,
        default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future:
    """Schedule a streaming-request/unary-response RPC on its executor.

    The handler's argument is a blocking iterator over incoming requests.
    """
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.stream_unary
    requests = _RequestIterator(state, rpc_event.call, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       lambda: requests, request_deserializer,
                       method_handler.response_serializer)
|
|
|
|
|
def _handle_stream_stream(
        rpc_event: cygrpc.BaseEvent, state: _RPCState,
        method_handler: grpc.RpcMethodHandler,
        default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future:
    """Schedule a streaming-request/streaming-response RPC on its executor.

    The handler's argument is a blocking iterator over incoming requests.
    """
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.stream_stream
    requests = _RequestIterator(state, rpc_event.call, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       lambda: requests, request_deserializer,
                       method_handler.response_serializer)
|
|
|
|
|
def _find_method_handler(
    rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
) -> Optional[grpc.RpcMethodHandler]:
    """Resolve the method handler for an incoming call.

    Generic handlers are consulted in order, and the first non-None handler
    wins. When an interceptor pipeline is configured, the lookup is routed
    through it.
    """

    def query_handlers(
        handler_call_details: _HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        for candidate in generic_handlers:
            found = candidate.service(handler_call_details)
            if found is not None:
                return found
        return None

    call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata)
    if interceptor_pipeline is None:
        return query_handlers(call_details)
    return interceptor_pipeline.execute(query_handlers, call_details)
|
|
|
|
|
def _reject_rpc(rpc_event: cygrpc.BaseEvent, status: cygrpc.StatusCode,
                details: bytes) -> _RPCState:
    """Fail a call immediately with ``status`` and ``details``.

    A single batch sends initial metadata, receives the client's close, and
    sends the status. The returned state is also what the batch's completion
    tag reports, so the server can track the call until that batch finishes.
    """
    rejected_state = _RPCState()
    operations = (
        _get_initial_metadata_operation(rejected_state, None),
        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
        cygrpc.SendStatusFromServerOperation(None, status, details,
                                             _EMPTY_FLAGS),
    )

    def on_rejected(ignored_event):
        return rejected_state, ()

    rpc_event.call.start_server_batch(operations, on_rejected)
    return rejected_state
|
|
|
|
|
def _handle_with_method_handler(
    rpc_event: cygrpc.BaseEvent, method_handler: grpc.RpcMethodHandler,
    thread_pool: futures.ThreadPoolExecutor
) -> Tuple[_RPCState, futures.Future]:
    """Start servicing a call whose method handler has been found.

    Starts the receive-close batch, then hands the call to the scheduler
    that matches the handler's request/response streaming flags.
    """
    schedulers = {
        (True, True): _handle_stream_stream,
        (True, False): _handle_stream_unary,
        (False, True): _handle_unary_stream,
        (False, False): _handle_unary_unary,
    }
    state = _RPCState()
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state))
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        arity = (bool(method_handler.request_streaming),
                 bool(method_handler.response_streaming))
        schedule = schedulers[arity]
        return state, schedule(rpc_event, state, method_handler, thread_pool)
|
|
|
|
|
def _handle_call(
    rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
    thread_pool: futures.ThreadPoolExecutor, concurrency_exceeded: bool
) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
    """Dispatch one request-call event.

    Returns:
      ``(None, None)`` if the event failed or carries no method name.
      ``(state, None)`` if the call was rejected: handler lookup raised, no
      handler matched, or the concurrency limit was exceeded.
      Otherwise, the ``(state, future)`` pair for the scheduled handler.
    """
    if not rpc_event.success:
        return None, None
    if rpc_event.call_details.method is None:
        return None, None
    try:
        method_handler = _find_method_handler(rpc_event, generic_handlers,
                                              interceptor_pipeline)
    except Exception as exception:  # pylint: disable=broad-except
        # This runs on the serving thread. If formatting the exception
        # raised (a broken __str__), the error would propagate out of the
        # serving loop before the call is rejected. Guard it the same way
        # _call_behavior does.
        try:
            details = 'Exception servicing handler: {}'.format(exception)
        except Exception:  # pylint: disable=broad-except
            details = 'Servicing handler raised unprintable Exception!'
        _LOGGER.exception(details)
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
                           b'Error in service handler!'), None
    if method_handler is None:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
                           b'Method not found!'), None
    if concurrency_exceeded:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
                           b'Concurrent RPC limit exceeded!'), None
    return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
|
|
|
|
|
@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of a server."""

    # Not serving: never started, or fully shut down.
    STOPPED = 'stopped'
    # Serving and accepting new calls.
    STARTED = 'started'
    # Shutdown requested; no new calls are requested while outstanding work
    # drains.
    GRACE = 'grace'
|
|
|
|
|
class _ServerState(object):
    """Server-wide mutable state, guarded by ``lock``.

    It is shared by the public server object, the serving thread and the
    completion callbacks.
    """

    # pylint: disable=too-many-instance-attributes
    lock: threading.RLock
    completion_queue: cygrpc.CompletionQueue
    server: cygrpc.Server
    generic_handlers: List[grpc.GenericRpcHandler]
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
    thread_pool: futures.ThreadPoolExecutor
    stage: _ServerStage
    termination_event: threading.Event
    shutdown_events: List[threading.Event]
    maximum_concurrent_rpcs: Optional[int]
    active_rpc_count: int
    rpc_states: Set[_RPCState]
    due: Set[str]
    server_deallocated: bool

    def __init__(self, completion_queue: cygrpc.CompletionQueue,
                 server: cygrpc.Server,
                 generic_handlers: Sequence[grpc.GenericRpcHandler],
                 interceptor_pipeline: Optional[_interceptor._ServicePipeline],
                 thread_pool: futures.ThreadPoolExecutor,
                 maximum_concurrent_rpcs: Optional[int]):
        self.lock = threading.RLock()

        # Core objects.
        self.completion_queue = completion_queue
        self.server = server

        # Handler lookup and execution (handlers are copied into a new list).
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool

        # Lifecycle. Every event in shutdown_events is set when serving
        # stops; termination_event is always the first of them.
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]

        # Concurrency limiting (None means unlimited).
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # RPC states whose completion tags are still outstanding, and
        # server-level tags (request_call / shutdown) currently in flight.
        # Serving stops only when both are empty.
        self.rpc_states = set()
        self.due = set()

        # Presumably set when the public server object is garbage-collected;
        # the serving loop begins shutdown once it is True.
        self.server_deallocated = False
|
|
|
|
|
def _add_generic_handlers(
        state: _ServerState,
        generic_handlers: Iterable[grpc.GenericRpcHandler]) -> None:
    """Append ``generic_handlers`` to the server's handler list, in order."""
    with state.lock:
        registered = state.generic_handlers
        registered.extend(generic_handlers)
|
|
|
|
|
def _add_insecure_port(state: _ServerState, address: bytes) -> int:
    """Bind ``address`` without credentials.

    Returns whatever the core server's ``add_http2_port`` returns (presumably
    the bound port number).
    """
    with state.lock:
        bound = state.server.add_http2_port(address)
    return bound
|
|
|
|
|
def _add_secure_port(state: _ServerState, address: bytes,
                     server_credentials: grpc.ServerCredentials) -> int:
    """Bind ``address`` using the core credentials wrapped by
    ``server_credentials``.

    Returns whatever the core server's ``add_http2_port`` returns (presumably
    the bound port number).
    """
    core_credentials = server_credentials._credentials
    with state.lock:
        return state.server.add_http2_port(address, core_credentials)
|
|
|
|
|
def _request_call(state: _ServerState) -> None:
    """Ask the core server for the next incoming call.

    The same completion queue is used for both the call and its
    notification. The request-call tag is recorded as outstanding.
    """
    queue = state.completion_queue
    state.server.request_call(queue, queue, _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)
|
|
|
|
|
|
|
def _stop_serving(state: _ServerState) -> bool:
    """Tear the server down if nothing is outstanding.

    Callers in this module hold ``state.lock``. When no RPC states and no
    server-level tags remain, the core server is destroyed, every shutdown
    event is set, and the stage becomes STOPPED.

    Returns:
      True if the server was stopped, False otherwise.
    """
    if state.rpc_states or state.due:
        return False
    state.server.destroy()
    for event in state.shutdown_events:
        event.set()
    state.stage = _ServerStage.STOPPED
    return True
|
|
|
|
|
def _on_call_completed(state: _ServerState) -> None:
    """Release one slot of the concurrent-RPC count after a handler finishes."""
    with state.lock:
        state.active_rpc_count = state.active_rpc_count - 1
|
|
|
|
|
def _process_event_and_continue(state: _ServerState,
                                event: cygrpc.BaseEvent) -> bool:
    """Handle one completion-queue event on the serving thread.

    The event's tag decides what happens:
      * ``_SHUTDOWN_TAG``: core shutdown has completed; try to stop serving.
      * ``_REQUEST_CALL_TAG``: a new call arrived. Dispatch or reject it and,
        while still STARTED, request the next call.
      * Anything else: a per-RPC tag, i.e. a callable returning
        ``(rpc_state, callbacks)``. Run the callbacks and, once the RPC is
        fully finished, drop its state.

    Returns:
      False once the server has stopped (the serving loop should exit);
      True otherwise.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif event.tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None and
                state.active_rpc_count >= state.maximum_concurrent_rpcs)
            rpc_state, rpc_future = _handle_call(event, state.generic_handlers,
                                                 state.interceptor_pipeline,
                                                 state.thread_pool,
                                                 concurrency_exceeded)
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                # Decremented again by _on_call_completed when the handler's
                # future finishes.
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state))
            # Only ask for another call while serving. In GRACE no new call
            # is requested, so the server can drain and stop.
            if state.stage is _ServerStage.STARTED:
                _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        rpc_state, callbacks = event.tag(event)
        # Callbacks run without holding state.lock. Failures are logged and
        # do not stop the loop.
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception('Exception calling callback!')
        # A non-None rpc_state means the RPC has no outstanding batches left.
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
|
|
|
|
|
def _serve(state: _ServerState) -> None:
    """Body of the background serving thread launched by ``_start``.

    Polls the completion queue repeatedly and dispatches every non-timeout
    event to ``_process_event_and_continue``. Returns once that function
    reports that the server has stopped.

    Each poll is bounded by ``_DEALLOCATED_SERVER_CHECK_PERIOD_S``. This lets
    the loop notice ``state.server_deallocated`` (set by ``_Server.__del__``)
    and begin shutdown even when no events arrive.
    """
    while True:
        # time.time() + period: presumably an absolute deadline for poll.
        timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(timeout)
        if state.server_deallocated:
            # The flag is set without locking by _Server.__del__; the actual
            # shutdown is started here, on the serving thread.
            _begin_shutdown_once(state)
        # A queue_timeout event only means the deadline passed; there is
        # nothing to dispatch.
        if event.completion_type != cygrpc.CompletionType.queue_timeout:
            if not _process_event_and_continue(state, event):
                return
        # Drop the reference to the processed event before polling again, so
        # it can be freed promptly.
        # NOTE(review): presumably this keeps an event from holding a core
        # Call object alive across polls; confirm before removing.
        event = None
|
|
|
|
|
def _begin_shutdown_once(state: _ServerState) -> None:
    """Ask the core server to shut down, at most once.

    This only acts on a server in the STARTED stage. Such a server is asked to
    shut down, with the completion reported on its completion queue under
    ``_SHUTDOWN_TAG``. The stage then moves to GRACE and the tag is recorded in
    ``state.due``. In every other stage the call is a no-op.

    ``_stop`` calls this while already holding ``state.lock``, so that lock
    must be reentrant.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)
|
|
|
|
|
def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Begin stopping the server.

    Args:
      state: the shared server state.
      grace: either None or a number of seconds.
        - None: cancel every in-flight call immediately and block until the
          server has fully stopped.
        - A number of seconds: start a background thread that cancels any
          still-running calls after that many seconds, and return without
          waiting.

    Returns:
      A threading.Event that is set once the server has fully stopped. If the
      server was already stopped on entry, the event is returned already set.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            # Nothing to stop: hand back an event that is already set.
            shutdown_event = threading.Event()
            shutdown_event.set()
            return shutdown_event
        else:
            # Has no effect unless the stage is STARTED, so repeated stop()
            # calls do not re-issue the core shutdown.
            _begin_shutdown_once(state)
            # Registered events are set by _stop_serving when the server is
            # finally destroyed.
            shutdown_event = threading.Event()
            state.shutdown_events.append(shutdown_event)
            if grace is None:
                state.server.cancel_all_calls()
            else:

                def cancel_all_calls_after_grace():
                    # Returns early if shutdown completes before `grace`
                    # seconds elapse; cancel_all_calls runs in either case.
                    shutdown_event.wait(timeout=grace)
                    with state.lock:
                        state.server.cancel_all_calls()

                thread = threading.Thread(target=cancel_all_calls_after_grace)
                thread.start()
                # With a grace period, return without waiting.
                return shutdown_event
    # Only reached when grace is None. Block outside the lock so that the
    # serving thread can acquire state.lock and finish shutting down.
    shutdown_event.wait()
    return shutdown_event
|
|
|
|
|
def _start(state: _ServerState) -> None:
    """Start the core server and launch the daemon serving thread.

    Raises:
      ValueError: if the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError('Cannot start already-started server!')
        state.server.start()
        state.stage = _ServerStage.STARTED
        # Post the first request for an incoming call before polling begins.
        _request_call(state)
        serving_thread = threading.Thread(target=_serve,
                                          args=(state,),
                                          daemon=True)
        serving_thread.start()
|
|
|
|
|
def _validate_generic_rpc_handlers( |
|
generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None: |
|
for generic_rpc_handler in generic_rpc_handlers: |
|
service_attribute = getattr(generic_rpc_handler, 'service', None) |
|
if service_attribute is None: |
|
raise AttributeError( |
|
'"{}" must conform to grpc.GenericRpcHandler type but does ' |
|
'not have "service" method!'.format(generic_rpc_handler)) |
|
|
|
|
|
def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression]
) -> Sequence[ChannelArgumentType]:
    """Append the compression channel option to the caller's options.

    Returns:
      A new tuple made of ``base_options`` followed by the option tuple that
      ``_compression.create_channel_option`` builds for ``compression``.
    """
    compression_suffix = _compression.create_channel_option(compression)
    combined = tuple(base_options)
    return combined + compression_suffix
|
|
|
|
|
class _Server(grpc.Server):
    """``grpc.Server`` implementation backed by a cygrpc core server.

    All mutable server state lives in a single ``_ServerState``. The public
    methods below share it with the background serving thread that
    ``start()`` launches.
    """

    # Shared with the serving thread; assigned at the end of __init__.
    _state: _ServerState

    def __init__(self, thread_pool: futures.ThreadPoolExecutor,
                 generic_handlers: Sequence[grpc.GenericRpcHandler],
                 interceptors: Sequence[grpc.ServerInterceptor],
                 options: Sequence[ChannelArgumentType],
                 maximum_concurrent_rpcs: Optional[int],
                 compression: Optional[grpc.Compression], xds: bool):
        """Create the core server and register its completion queue.

        Args:
          thread_pool: executor stored in the server state and handed to
            call handling (defined elsewhere).
          generic_handlers: initial generic RPC handlers.
          interceptors: server interceptors, combined into a pipeline via
            ``_interceptor.service_pipeline``.
          options: channel arguments for the core server; the compression
            option is appended to them.
          maximum_concurrent_rpcs: cap on concurrently active RPCs, or None
            for no cap.
          compression: default compression algorithm, or None.
          xds: forwarded to ``cygrpc.Server``.
        """
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression), xds)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(completion_queue, server, generic_handlers,
                                   _interceptor.service_pipeline(interceptors),
                                   thread_pool, maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(
            self,
            generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None:
        """Validate and register additional generic RPC handlers.

        Raises:
          AttributeError: if a handler lacks a ``service`` attribute.
        """
        # Materialize first. Validation iterates the handlers, so a one-shot
        # iterable (e.g. a generator) would otherwise reach
        # _add_generic_handlers already exhausted, and the handlers would be
        # silently dropped.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address: str) -> int:
        """Bind ``address`` without TLS and return the bound port number."""
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address)))

    def add_secure_port(self, address: str,
                        server_credentials: grpc.ServerCredentials) -> int:
        """Bind ``address`` with ``server_credentials`` and return the port."""
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(self._state, _common.encode(address),
                             server_credentials))

    def start(self) -> None:
        """Start serving; raises ValueError if already started."""
        _start(self._state)

    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
        """Block until the server terminates or ``timeout`` seconds pass.

        Returns:
          The result of ``_common.wait``. Per the ``grpc.Server`` contract,
          presumably True if the timeout elapsed before termination.
        """
        # The event's wait and is_set are passed separately rather than
        # calling termination_event.wait() directly.
        # NOTE(review): presumably this keeps the wait interruptible (e.g. by
        # Ctrl+C) on all platforms; confirm against _common.wait.
        return _common.wait(self._state.termination_event.wait,
                            self._state.termination_event.is_set,
                            timeout=timeout)

    def stop(self, grace: Optional[float]) -> threading.Event:
        """Stop the server.

        ``grace`` controls the behavior:
          - None: cancel all calls now and block until stopped.
          - A number of seconds: cancel remaining calls after that delay,
            without blocking.

        Returns:
          An event that is set once the server has fully stopped.
        """
        return _stop(self._state, grace)

    def __del__(self):
        # _state may be missing if __init__ raised before assigning it.
        if hasattr(self, '_state'):
            # No lock is taken here. Only a flag is set; the serving thread
            # checks it after every poll and begins shutdown itself.
            self._state.server_deallocated = True
|
|
|
|
|
def create_server(thread_pool: futures.ThreadPoolExecutor,
                  generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
                  interceptors: Sequence[grpc.ServerInterceptor],
                  options: Sequence[ChannelArgumentType],
                  maximum_concurrent_rpcs: Optional[int],
                  compression: Optional[grpc.Compression],
                  xds: bool) -> _Server:
    """Validate the generic handlers and construct a ``_Server``.

    Args:
      thread_pool: executor handed to the server.
      generic_rpc_handlers: initial generic RPC handlers; each must expose a
        ``service`` attribute.
      interceptors: server interceptors.
      options: channel arguments for the core server.
      maximum_concurrent_rpcs: cap on concurrently active RPCs, or None.
      compression: default compression algorithm, or None.
      xds: forwarded to the core server.

    Returns:
      A new, not-yet-started ``_Server``.

    Raises:
      AttributeError: if a handler lacks a ``service`` attribute.
    """
    # Materialize once. Validation iterates the handlers, and a one-shot
    # iterable would otherwise reach _Server already exhausted. A fresh list
    # also keeps the server from aliasing the caller's container.
    generic_rpc_handlers = list(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
                   maximum_concurrent_rpcs, compression, xds)
|
|