| # Copyright 2016 gRPC authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Service-side implementation of gRPC Python.""" |
| |
| from __future__ import annotations |
| |
| import collections |
| from concurrent import futures |
| import contextvars |
| import enum |
| import logging |
| import threading |
| import time |
| import traceback |
| from typing import ( |
| Any, |
| Callable, |
| Iterable, |
| Iterator, |
| List, |
| Mapping, |
| Optional, |
| Sequence, |
| Set, |
| Tuple, |
| Union, |
| ) |
| |
| import grpc # pytype: disable=pyi-error |
| from grpc import _common # pytype: disable=pyi-error |
| from grpc import _compression # pytype: disable=pyi-error |
| from grpc import _interceptor # pytype: disable=pyi-error |
| from grpc._cython import cygrpc |
| from grpc._typing import ArityAgnosticMethodHandler |
| from grpc._typing import ChannelArgumentType |
| from grpc._typing import DeserializingFunction |
| from grpc._typing import MetadataType |
| from grpc._typing import NullaryCallbackType |
| from grpc._typing import ResponseType |
| from grpc._typing import SerializingFunction |
| from grpc._typing import ServerCallbackTag |
| from grpc._typing import ServerTagCallbackType |
| |
_LOGGER = logging.getLogger(__name__)

# Tags identifying server-level completion-queue events (passed to
# request_call and compared against event.tag in the serving loop).
_SHUTDOWN_TAG = "shutdown"
_REQUEST_CALL_TAG = "request_call"

# Tokens added to _RPCState.due while the matching batch is outstanding and
# removed by that batch's completion callback. Tokens for combined batches
# spell out their parts joined by " * ".
_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server"
_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata"
_RECEIVE_MESSAGE_TOKEN = "receive_message"
_SEND_MESSAGE_TOKEN = "send_message"
_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    "send_initial_metadata * send_message"
)
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    "send_initial_metadata * send_status_from_server"
)

# Values held by _RPCState.client: the client side is still open, has been
# closed, or the RPC was cancelled.
_OPEN = "open"
_CLOSED = "closed"
_CANCELLED = "cancelled"

# Flags argument for operations that need no special flags.
_EMPTY_FLAGS = 0

# Not referenced in the visible part of this file; presumably a polling
# period (seconds) and an effectively-infinite timeout for the serving loop.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
| |
| |
def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
    """Return the raw payload of the first operation in *request_event*.

    Callers treat a ``None`` result as "no message was delivered".
    """
    first_operation = request_event.batch_operations[0]
    return first_operation.message()
| |
| |
def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate an application ``grpc.StatusCode`` to its cygrpc value.

    Codes missing from the translation table map to ``unknown``.
    """
    mapped = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if mapped is None:
        return cygrpc.StatusCode.unknown
    return mapped
| |
| |
def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Status code for a normally completing RPC: OK unless the app set one."""
    application_code = state.code
    return (
        cygrpc.StatusCode.ok
        if application_code is None
        else _application_code(application_code)
    )
| |
| |
def _abortion_code(
    state: _RPCState, code: cygrpc.StatusCode
) -> cygrpc.StatusCode:
    """Status code for an aborted RPC.

    A code set by the application takes precedence over the fallback *code*.
    """
    if state.code is not None:
        return _application_code(state.code)
    return code
| |
| |
def _details(state: _RPCState) -> bytes:
    """Return the status details recorded on *state*, or ``b""`` if unset."""
    recorded = state.details
    return recorded if recorded is not None else b""
| |
| |
class _HandlerCallDetails(
    collections.namedtuple(
        "_HandlerCallDetails", ["method", "invocation_metadata"]
    ),
    grpc.HandlerCallDetails,
):
    """``(method, invocation_metadata)`` pair given to generic handlers.

    Built from the request-call event when looking up a method handler.
    """
| |
| |
class _RPCState(object):
    """Mutable bookkeeping for a single server-side RPC.

    Shared between the thread running the application handler and the
    callbacks invoked when batches started on the call complete. This module
    accesses it mostly while holding ``condition``.
    """

    # contextvars context in which handler lookup and handler code run.
    context: contextvars.Context
    # Serializes access to this state; notified when batches complete.
    condition: threading.Condition
    # Tokens of started batches whose completion callbacks have not yet run.
    # (Previously written as ``due = Set[str]``, which assigned a typing
    # alias as a class attribute instead of declaring an annotation.)
    due: Set[str]
    # Most recently received request message not yet consumed by a reader.
    request: Any
    # One of _OPEN, _CLOSED, _CANCELLED.
    client: str
    # True until initial metadata has been sent on the call.
    initial_metadata_allowed: bool
    # Compression requested by the application via set_compression().
    compression_algorithm: Optional[grpc.Compression]
    # Per-message request to send the next message uncompressed.
    disable_next_compression: bool
    # Trailing metadata to attach to the status.
    trailing_metadata: Optional[MetadataType]
    # Application-set status code; None means OK on normal completion.
    code: Optional[grpc.StatusCode]
    # Application-set status details, already encoded to bytes.
    details: Optional[bytes]
    # True once a status batch has been started on the call.
    statused: bool
    # RpcErrors raised into application code by this module.
    rpc_errors: List[Exception]
    # Termination callbacks; set to None once they have been handed off.
    callbacks: Optional[List[NullaryCallbackType]]
    # True once the application called abort().
    aborted: bool

    def __init__(self):
        self.context = contextvars.Context()
        self.condition = threading.Condition()
        self.due = set()
        self.request = None
        self.client = _OPEN
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.trailing_metadata = None
        self.code = None
        self.details = None
        self.statused = False
        self.rpc_errors = []
        self.callbacks = []
        self.aborted = False
| |
| |
def _raise_rpc_error(state: _RPCState) -> None:
    """Raise a new ``grpc.RpcError`` after recording it on *state*.

    Recording it in ``state.rpc_errors`` lets the exception handlers in this
    module recognize it and not report it as an application failure.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
| |
| |
def _possibly_finish_call(
    state: _RPCState, token: str
) -> ServerTagCallbackType:
    """Retire *token* and, if the RPC is fully done, hand off its callbacks.

    Callers in this module invoke this while holding ``state.condition``.

    Returns:
      ``(state, callbacks)`` when the RPC is no longer active and no batches
      remain outstanding (``state.callbacks`` is then set to None);
      otherwise ``(None, ())``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    finished_callbacks = state.callbacks
    state.callbacks = None
    return state, finished_callbacks
| |
| |
def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion callback for a status-sending batch.

    The callback retires *token* under ``state.condition`` and returns the
    result of ``_possibly_finish_call``.
    """

    def send_status_from_server(unused_event):
        with state.condition:
            outcome = _possibly_finish_call(state, token)
        return outcome

    return send_status_from_server
| |
| |
def _get_initial_metadata(
    state: _RPCState, metadata: Optional[MetadataType]
) -> Optional[MetadataType]:
    """Return *metadata*, with a compression entry prepended if one is set.

    When ``state.compression_algorithm`` is falsy, *metadata* is returned
    unchanged (possibly None). Otherwise the result is a tuple whose first
    element is the compression metadatum.
    """
    with state.condition:
        algorithm = state.compression_algorithm
        if not algorithm:
            return metadata
        compression_entry = _compression.compression_algorithm_to_metadata(
            algorithm
        )
        if metadata is None:
            return (compression_entry,)
        return (compression_entry,) + tuple(metadata)
| |
| |
def _get_initial_metadata_operation(
    state: _RPCState, metadata: Optional[MetadataType]
) -> cygrpc.Operation:
    """Build a SendInitialMetadata operation, adding compression metadata."""
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS
    )
| |
| |
def _abort(
    state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
) -> None:
    """Start a batch that sends a terminating status on *call*.

    Does nothing if the client already cancelled. Code and details set by the
    application on *state* take precedence over *code* and *details*. Initial
    metadata is included if it has not been sent yet. Callers in this module
    hold ``state.condition``.
    """
    if state.client is _CANCELLED:
        return
    status_operation = cygrpc.SendStatusFromServerOperation(
        state.trailing_metadata,
        _abortion_code(state, code),
        details if state.details is None else state.details,
        _EMPTY_FLAGS,
    )
    if state.initial_metadata_allowed:
        operations = (
            _get_initial_metadata_operation(state, None),
            status_operation,
        )
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        operations = (status_operation,)
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
| |
| |
def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
    """Build the callback for the ReceiveCloseOnServer batch.

    The callback records whether the client cancelled or closed, wakes
    threads waiting on ``state.condition``, and retires
    ``_RECEIVE_CLOSE_ON_SERVER_TOKEN``.
    """

    def receive_close_on_server(event):
        with state.condition:
            close_operation = event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                # Only OPEN moves to CLOSED; CANCELLED is never downgraded.
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(
                state, _RECEIVE_CLOSE_ON_SERVER_TOKEN
            )

    return receive_close_on_server
| |
| |
def _receive_message(
    state: _RPCState,
    call: cygrpc.Call,
    request_deserializer: Optional[DeserializingFunction],
) -> ServerCallbackTag:
    """Build the completion callback for a ReceiveMessage batch.

    Args:
      state: Per-RPC state, updated under ``state.condition``.
      call: The call; used to abort the RPC if deserialization fails.
      request_deserializer: Applied to the raw payload via
        ``_common.deserialize``.

    Returns:
      A callback that stores the received request on ``state.request`` (or
      marks the client CLOSED when no payload arrived), wakes waiters, and
      retires ``_RECEIVE_MESSAGE_TOKEN``.
    """

    def receive_message(receive_message_event):
        serialized_request = _serialized_request(receive_message_event)
        if serialized_request is None:
            # No payload is treated as the end of the client's messages.
            # Only OPEN moves to CLOSED, so CANCELLED is preserved.
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
        else:
            # Deserialize outside the lock. A None result means
            # deserialization failed, and the RPC is aborted with INTERNAL.
            request = _common.deserialize(
                serialized_request, request_deserializer
            )
            with state.condition:
                if request is None:
                    _abort(
                        state,
                        call,
                        cygrpc.StatusCode.internal,
                        b"Exception deserializing request!",
                    )
                else:
                    state.request = request
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message
| |
| |
def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
    """Build the completion callback for a SendInitialMetadata batch."""

    def send_initial_metadata(unused_event):
        with state.condition:
            outcome = _possibly_finish_call(
                state, _SEND_INITIAL_METADATA_TOKEN
            )
        return outcome

    return send_initial_metadata
| |
| |
def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion callback for a message-sending batch.

    The callback wakes the sender waiting on ``state.condition`` and
    retires *token*.
    """

    def send_message(unused_event):
        with state.condition:
            state.condition.notify_all()
            outcome = _possibly_finish_call(state, token)
        return outcome

    return send_message
| |
| |
class _Context(grpc.ServicerContext):
    """Server-side ``grpc.ServicerContext`` passed to application handlers.

    Call-level information (peer, deadline, invocation metadata, auth
    context) comes from the request-call event. Status, metadata and
    compression settings are stored on the per-RPC ``_RPCState``; mutators
    take ``state.condition`` because batch-completion callbacks share that
    state.
    """

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        rpc_event: cygrpc.BaseEvent,
        state: _RPCState,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        """Return True until the client cancels or a status has been sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        """Return seconds left before the call deadline, never negative.

        Assumes ``call_details.deadline`` is on the ``time.time()`` clock.
        """
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        """Cancel the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """Register *callback* to be handed off when the RPC finishes.

        Returns False without registering once the callbacks have already
        been handed off (``state.callbacks`` is None).
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def disable_next_message_compression(self) -> None:
        """Send the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        """Return the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        """Return the decoded peer address of the call."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        """Return the peer identities reported by cygrpc, if any."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        """Return the decoded peer identity key, or None if there is none."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        """Return the call's auth context with decoded keys ({} if none)."""
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        """Set the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send *initial_metadata* immediately.

        Raises:
          grpc.RpcError: the client cancelled the RPC.
          ValueError: initial metadata has already been sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata
                    )
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state)
                    )
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError("Initial metadata no longer allowed!")

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        """Set metadata to be sent with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the trailing metadata set so far, if any."""
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record an abort status on the RPC and raise to unwind the handler.

        ``StatusCode.OK`` is not a valid abort code; it is logged and replaced
        with ``UNKNOWN`` and empty details.

        Raises:
          Exception: always. The handler wrapper in this module sees
            ``state.aborted`` and sends the recorded status.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                "abort() called with StatusCode.OK; returning UNKNOWN"
            )
            code = grpc.StatusCode.UNKNOWN
            details = ""
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Like abort(), but also takes trailing metadata from *status*."""
        # Take the lock as set_trailing_metadata() does, so this write is
        # serialized with the other accesses to the shared RPC state.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        """Set the status code to send when the RPC completes."""
        with self._state.condition:
            self._state.code = code

    def code(self) -> grpc.StatusCode:
        """Return the status code set so far (None if unset)."""
        return self._state.code

    def set_details(self, details: str) -> None:
        """Set the status details (stored encoded) to send on completion."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> bytes:
        """Return the encoded status details set so far (None if unset)."""
        return self._state.details

    def _finalize_state(self) -> None:
        """No-op in this implementation."""
        pass
| |
| |
class _RequestIterator(object):
    """Blocking iterator over the request messages of a streaming-request RPC.

    Each ``next()`` starts one ReceiveMessage batch on the call and waits on
    ``state.condition`` until the batch delivers a request, finishes without
    one (StopIteration), or the client cancels (grpc.RpcError).
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        state: _RPCState,
        call: cygrpc.Call,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        """Start a ReceiveMessage batch, or raise if the RPC is over.

        Called with ``self._state.condition`` held.

        Raises:
          grpc.RpcError: the client cancelled the RPC.
          StopIteration: a status has already been sent.
        """
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif not _is_rpc_state_active(self._state):
            raise StopIteration()
        else:
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(
                    self._state, self._call, self._request_deserializer
                ),
            )
            self._state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self) -> Any:
        """Consume the pending request, if any.

        Called with ``self._state.condition`` held, after a wake-up.

        Returns:
          The received request, or None if the receive batch is still
          outstanding and the caller should keep waiting.

        Raises:
          grpc.RpcError: the client cancelled the RPC.
          StopIteration: no request is pending and the receive batch has
            completed (e.g. the client finished sending).
        """
        if self._state.client is _CANCELLED:
            # Always raises.
            _raise_rpc_error(self._state)
        if (
            self._state.request is None
            and _RECEIVE_MESSAGE_TOKEN not in self._state.due
        ):
            raise StopIteration()
        request = self._state.request
        self._state.request = None
        return request

    def _next(self) -> Any:
        """Receive and return the next request, blocking until it arrives."""
        with self._state.condition:
            self._raise_or_start_receive_message()
            while True:
                self._state.condition.wait()
                request = self._look_for_request()
                if request is not None:
                    return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    def next(self) -> Any:
        return self._next()
| |
| |
def _unary_request(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    request_deserializer: Optional[DeserializingFunction],
) -> Callable[[], Any]:
    """Build a thunk that receives exactly one request message.

    The thunk starts a ReceiveMessage batch and blocks on ``state.condition``
    until a request arrives. It returns None (and the caller skips the
    handler) when the RPC is already inactive, when the client cancels, or
    when the client closes without sending a message. In that last case the
    RPC is first aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            else:
                rpc_event.call.start_server_batch(
                    (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                    _receive_message(
                        state, rpc_event.call, request_deserializer
                    ),
                )
                state.due.add(_RECEIVE_MESSAGE_TOKEN)
                while True:
                    state.condition.wait()
                    # A wake-up with no request while the client is still
                    # OPEN falls through and keeps waiting.
                    if state.request is None:
                        if state.client is _CLOSED:
                            details = '"{}" requires exactly one request message.'.format(
                                rpc_event.call_details.method
                            )
                            _abort(
                                state,
                                rpc_event.call,
                                cygrpc.StatusCode.unimplemented,
                                _common.encode(details),
                            )
                            return None
                        elif state.client is _CANCELLED:
                            return None
                    else:
                        # Consume the request so it is not seen twice.
                        request = state.request
                        state.request = None
                        return request

    return unary_request
| |
| |
def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None,
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application *behavior* inside a servicer context.

    Args:
      send_response_callback: When given, passed as a third argument to
        *behavior* (the non-blocking streaming style).

    Returns:
      ``(response_or_iterator, True)`` on success, or ``(None, False)`` if
      *behavior* raised. On failure the RPC has been aborted, unless the
      exception is an RpcError this module raised itself and the
      application did not call abort().
    """
    from grpc import _create_servicer_context  # pytype: disable=pyi-error

    with _create_servicer_context(
        rpc_event, state, request_deserializer
    ) as context:
        try:
            response_or_iterator = None
            if send_response_callback is not None:
                response_or_iterator = behavior(
                    argument, context, send_response_callback
                )
            else:
                response_or_iterator = behavior(argument, context)
            return response_or_iterator, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # context.abort() was called: send the recorded status.
                    # _abort prefers state.code/state.details over these
                    # fallbacks.
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        b"RPC Aborted",
                    )
                elif exception not in state.rpc_errors:
                    # A genuine application exception (not one of our own
                    # RpcErrors): log it and fail the RPC with UNKNOWN.
                    try:
                        details = "Exception calling application: {}".format(
                            exception
                        )
                    except Exception:  # pylint: disable=broad-except
                        # str(exception) itself raised; log the traceback.
                        details = (
                            "Calling application raised unprintable Exception!"
                        )
                        _LOGGER.exception(
                            traceback.format_exception(
                                type(exception),
                                exception,
                                exception.__traceback__,
                            )
                        )
                        traceback.print_exc()
                    _LOGGER.exception(details)
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        _common.encode(details),
                    )
            return None, False
| |
| |
def _take_response_from_response_iterator(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response_iterator: Iterator[ResponseType],
) -> Tuple[ResponseType, bool]:
    """Advance the application's response iterator by one item.

    Returns:
      ``(response, True)`` for the next item; ``(None, True)`` when the
      iterator is exhausted; ``(None, False)`` when it raised. In the last
      case the RPC is aborted unless the error is an RpcError this module
      raised and the application did not call abort().
    """
    try:
        response = next(response_iterator)
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    b"RPC Aborted",
                )
            elif exception not in state.rpc_errors:
                details = "Exception iterating responses: {}".format(exception)
                _LOGGER.exception(details)
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    _common.encode(details),
                )
        return None, False
    else:
        return response, True
| |
| |
def _serialize_response(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response: Any,
    response_serializer: Optional[SerializingFunction],
) -> Optional[bytes]:
    """Serialize *response*, aborting the RPC with INTERNAL on failure.

    Returns the serialized bytes, or None if serialization failed.
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is not None:
        return serialized_response
    with state.condition:
        _abort(
            state,
            rpc_event.call,
            cygrpc.StatusCode.internal,
            b"Failed to serialize response!",
        )
    return None
| |
| |
def _get_send_message_op_flags_from_state(
    state: _RPCState,
) -> Union[int, cygrpc.WriteFlag]:
    """Write flags for the next message: no_compress if requested, else 0."""
    return (
        cygrpc.WriteFlag.no_compress
        if state.disable_next_compression
        else _EMPTY_FLAGS
    )
| |
| |
def _reset_per_message_state(state: _RPCState) -> None:
    """Clear settings that apply only to the message just sent."""
    condition = state.condition
    with condition:
        state.disable_next_compression = False
| |
| |
def _send_response(
    rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
) -> bool:
    """Send one serialized response message and wait for the batch to finish.

    Initial metadata is sent along with the message if it has not been sent
    yet.

    Returns:
      False immediately if the RPC is no longer active; otherwise whether
      the RPC is still active once the send batch has completed.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    ),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    ),
                )
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(
                operations, _send_message(state, token)
            )
            state.due.add(token)
            # Per-message flags were read above; clear them for the next one.
            _reset_per_message_state(state)
            # The _send_message callback retires the token and notifies.
            while True:
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)
| |
| |
def _status(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    serialized_response: Optional[bytes],
) -> None:
    """Send the final status, optionally with a last response message.

    Does nothing if the client cancelled. Initial metadata is included if it
    has not been sent yet. The code and details come from the application
    (OK and empty details by default).
    """
    with state.condition:
        if state.client is not _CANCELLED:
            code = _completion_code(state)
            details = _details(state)
            operations = [
                cygrpc.SendStatusFromServerOperation(
                    state.trailing_metadata, code, details, _EMPTY_FLAGS
                ),
            ]
            if state.initial_metadata_allowed:
                operations.append(_get_initial_metadata_operation(state, None))
            if serialized_response is not None:
                # Flags must be read before _reset_per_message_state below.
                operations.append(
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    )
                )
            rpc_event.call.start_server_batch(
                operations,
                _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
            )
            state.statused = True
            _reset_per_message_state(state)
            state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
| |
| |
def _unary_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Run a single-response handler on a pool thread and send its result.

    *argument_thunk* yields the request (or request iterator); a None result
    means the RPC should not proceed. The response is serialized and sent
    together with the final status. Any unexpected exception is printed
    rather than propagated, and the cygrpc context is always uninstalled.

    ``request_deserializer`` is a deserializing function; it was previously
    mis-annotated as ``SerializingFunction``.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    try:
        argument = argument_thunk()
        if argument is not None:
            response, proceed = _call_behavior(
                rpc_event, state, behavior, argument, request_deserializer
            )
            if proceed:
                serialized_response = _serialize_response(
                    rpc_event, state, response, response_serializer
                )
                if serialized_response is not None:
                    _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
| |
| |
def _stream_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Run a streaming-response handler on a pool thread.

    Behaviors flagged ``experimental_non_blocking`` receive a send callback
    and push responses themselves. Otherwise the returned iterator is
    drained and each item is sent. A None response sends the final status.
    Unexpected exceptions are printed, and the cygrpc context is always
    uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        if response is not None:
            payload = _serialize_response(
                rpc_event, state, response, response_serializer
            )
            if payload is not None:
                _send_response(rpc_event, state, payload)
        else:
            # End of stream: send the status with no trailing message.
            _status(rpc_event, state, None)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        non_blocking = (
            hasattr(behavior, "experimental_non_blocking")
            and behavior.experimental_non_blocking
        )
        if non_blocking:
            _call_behavior(
                rpc_event,
                state,
                behavior,
                argument,
                request_deserializer,
                send_response_callback=send_response,
            )
        else:
            response_iterator, proceed = _call_behavior(
                rpc_event, state, behavior, argument, request_deserializer
            )
            if proceed:
                _send_message_callback_to_blocking_iterator_adapter(
                    rpc_event, state, send_response, response_iterator
                )
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
| |
| |
def _is_rpc_state_active(state: _RPCState) -> bool:
    """True until the client cancels or a status batch has been started."""
    return not (state.client is _CANCELLED or state.statused)
| |
| |
def _send_message_callback_to_blocking_iterator_adapter(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    send_response_callback: Callable[[ResponseType], None],
    response_iterator: Iterator[ResponseType],
) -> None:
    """Drive a blocking response iterator through *send_response_callback*.

    Each item (and a final None when the iterator is exhausted) is passed to
    the callback. The loop stops when the iterator fails or the RPC is no
    longer active.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator
        )
        if not proceed:
            break
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            break
| |
| |
def _select_thread_pool_for_behavior(
    behavior: ArityAgnosticMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.ThreadPoolExecutor:
    """Prefer the behavior's own ``experimental_thread_pool`` when valid.

    Falls back to *default_thread_pool* when the attribute is missing or is
    not a ThreadPoolExecutor.
    """
    candidate = getattr(behavior, "experimental_thread_pool", None)
    if isinstance(candidate, futures.ThreadPoolExecutor):
        return candidate
    return default_thread_pool
| |
| |
def _handle_unary_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-request, unary-response handler on a thread pool.

    The work runs inside ``state.context`` on the behavior's own pool if it
    has one, otherwise on *default_thread_pool*.
    """
    behavior = method_handler.unary_unary
    deserializer = method_handler.request_deserializer
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    read_single_request = _unary_request(rpc_event, state, deserializer)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        read_single_request,
        deserializer,
        method_handler.response_serializer,
    )
| |
| |
def _handle_unary_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-request, streaming-response handler on a thread pool.

    The work runs inside ``state.context`` on the behavior's own pool if it
    has one, otherwise on *default_thread_pool*.
    """
    behavior = method_handler.unary_stream
    deserializer = method_handler.request_deserializer
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    read_single_request = _unary_request(rpc_event, state, deserializer)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        read_single_request,
        deserializer,
        method_handler.response_serializer,
    )
| |
| |
def _handle_stream_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a streaming-request, unary-response handler on a thread pool.

    The handler receives a ``_RequestIterator`` over the incoming messages.
    """
    behavior = method_handler.stream_unary
    deserializer = method_handler.request_deserializer
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    requests = _RequestIterator(state, rpc_event.call, deserializer)

    def argument_thunk():
        return requests

    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        deserializer,
        method_handler.response_serializer,
    )
| |
| |
def _handle_stream_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a bidirectional-streaming handler on a thread pool.

    The handler receives a ``_RequestIterator`` over the incoming messages.
    """
    behavior = method_handler.stream_stream
    deserializer = method_handler.request_deserializer
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    requests = _RequestIterator(state, rpc_event.call, deserializer)

    def argument_thunk():
        return requests

    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        deserializer,
        method_handler.response_serializer,
    )
| |
| |
def _find_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
) -> Optional[grpc.RpcMethodHandler]:
    """Resolve the RpcMethodHandler for the call in *rpc_event*.

    Generic handlers are tried in order, and the first non-None result wins.
    When an interceptor pipeline is configured, the lookup goes through it.
    Either way the lookup runs inside ``state.context``.
    """
    handler_call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata,
    )

    def query_handlers(
        handler_call_details: _HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        for candidate in generic_handlers:
            resolved = candidate.service(handler_call_details)
            if resolved is not None:
                return resolved
        return None

    if interceptor_pipeline is None:
        return state.context.run(query_handlers, handler_call_details)
    return state.context.run(
        interceptor_pipeline.execute, query_handlers, handler_call_details
    )
| |
| |
def _reject_rpc(
    rpc_event: cygrpc.BaseEvent,
    rpc_state: _RPCState,
    status: cygrpc.StatusCode,
    details: bytes,
):
    """Finish a call that will not be handled, in a single batch.

    Sends initial metadata and the given *status*/*details* and receives the
    client's close. The completion callback returns ``(rpc_state, ())``,
    the same (state, callbacks) shape that ``_possibly_finish_call``
    produces, with no callbacks.
    """

    def on_rejection_complete(ignored_event):
        return (rpc_state, ())

    rejection_batch = (
        _get_initial_metadata_operation(rpc_state, None),
        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
        cygrpc.SendStatusFromServerOperation(
            None, status, details, _EMPTY_FLAGS
        ),
    )
    rpc_event.call.start_server_batch(rejection_batch, on_rejection_complete)
| |
| |
def _handle_with_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Start watching for client close, then dispatch on the handler's arity.

    Holds ``state.condition`` while starting the ReceiveCloseOnServer batch
    and scheduling the handler.
    """
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state),
        )
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        if method_handler.request_streaming:
            if method_handler.response_streaming:
                dispatch = _handle_stream_stream
            else:
                dispatch = _handle_stream_unary
        elif method_handler.response_streaming:
            dispatch = _handle_unary_stream
        else:
            dispatch = _handle_unary_unary
        return dispatch(rpc_event, state, method_handler, thread_pool)
| |
| |
def _handle_call(
    rpc_event: cygrpc.BaseEvent,
    generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
    thread_pool: futures.ThreadPoolExecutor,
    concurrency_exceeded: bool,
) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
    """Route a newly requested call to its handler, or reject it.

    Returns:
      ``(None, None)`` if the event failed or carries no method.
      ``(rpc_state, None)`` if the call was rejected: handler lookup raised,
      no handler matched, or the concurrency limit was exceeded.
      Otherwise ``(rpc_state, future)`` for the scheduled handler work.
    """
    if not rpc_event.success or rpc_event.call_details.method is None:
        return None, None

    rpc_state = _RPCState()
    try:
        method_handler = _find_method_handler(
            rpc_event, rpc_state, generic_handlers, interceptor_pipeline
        )
    except Exception as exception:  # pylint: disable=broad-except
        _LOGGER.exception(
            "Exception servicing handler: {}".format(exception)
        )
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unknown,
            b"Error in service handler!",
        )
        return rpc_state, None

    if method_handler is None:
        rejection = (cygrpc.StatusCode.unimplemented, b"Method not found!")
    elif concurrency_exceeded:
        rejection = (
            cygrpc.StatusCode.resource_exhausted,
            b"Concurrent RPC limit exceeded!",
        )
    else:
        future = _handle_with_method_handler(
            rpc_event, rpc_state, method_handler, thread_pool
        )
        return rpc_state, future

    _reject_rpc(rpc_event, rpc_state, *rejection)
    return rpc_state, None
| |
| |
@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage recorded in ``_ServerState.stage``."""

    STOPPED = "stopped"
    STARTED = "started"
    GRACE = "grace"
| |
| |
class _ServerState(object):
    """Mutable state of one server, guarded by ``lock``.

    Holds the core server and completion queue, the registered handlers, the
    lifecycle stage, shutdown events, and bookkeeping of in-flight RPCs.
    """

    lock: threading.RLock
    completion_queue: cygrpc.CompletionQueue
    server: cygrpc.Server
    generic_handlers: List[grpc.GenericRpcHandler]
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
    thread_pool: futures.ThreadPoolExecutor
    stage: _ServerStage
    termination_event: threading.Event
    shutdown_events: List[threading.Event]
    maximum_concurrent_rpcs: Optional[int]
    active_rpc_count: int
    rpc_states: Set[_RPCState]
    due: Set[str]
    server_deallocated: bool

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        completion_queue: cygrpc.CompletionQueue,
        server: cygrpc.Server,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptor_pipeline: Optional[_interceptor._ServicePipeline],
        thread_pool: futures.ThreadPoolExecutor,
        maximum_concurrent_rpcs: Optional[int],
    ):
        # Core objects and caller-supplied configuration.
        self.completion_queue = completion_queue
        self.server = server
        # Copied into a list so more handlers can be appended later.
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs

        # Synchronization and lifecycle.
        self.lock = threading.RLock()
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False
| |
| |
| def _add_generic_handlers( |
| state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler] |
| ) -> None: |
| with state.lock: |
| state.generic_handlers.extend(generic_handlers) |
| |
| |
| def _add_insecure_port(state: _ServerState, address: bytes) -> int: |
| with state.lock: |
| return state.server.add_http2_port(address) |
| |
| |
| def _add_secure_port( |
| state: _ServerState, |
| address: bytes, |
| server_credentials: grpc.ServerCredentials, |
| ) -> int: |
| with state.lock: |
| return state.server.add_http2_port( |
| address, server_credentials._credentials |
| ) |
| |
| |
def _request_call(state: _ServerState) -> None:
    """Request the next incoming call from the core server and mark it as due.

    The callers in this module invoke this while holding ``state.lock``.
    """
    queue = state.completion_queue
    # Same queue for both the call and the notification.
    state.server.request_call(queue, queue, _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)
| |
| |
| # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. |
| def _stop_serving(state: _ServerState) -> bool: |
| if not state.rpc_states and not state.due: |
| state.server.destroy() |
| for shutdown_event in state.shutdown_events: |
| shutdown_event.set() |
| state.stage = _ServerStage.STOPPED |
| return True |
| else: |
| return False |
| |
| |
| def _on_call_completed(state: _ServerState) -> None: |
| with state.lock: |
| state.active_rpc_count -= 1 |
| |
| |
def _process_event_and_continue(
    state: _ServerState, event: cygrpc.BaseEvent
) -> bool:
    """Dispatch one completion-queue event; report whether to keep serving.

    Three kinds of tag are handled:
      * ``_SHUTDOWN_TAG``: the tag passed to ``server.shutdown`` by
        ``_begin_shutdown_once`` (presumably delivered once core shutdown
        completes -- confirm against cygrpc).
      * ``_REQUEST_CALL_TAG``: the tag posted by ``_request_call``; the event
        is handed to ``_handle_call``.
      * anything else: a per-RPC callable that returns ``(rpc_state,
        callbacks)``.

    Returns:
      False once ``_stop_serving`` reports the server fully stopped, else True.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif event.tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            # None means "no limit"; otherwise _handle_call rejects the call
            # with RESOURCE_EXHAUSTED once the limit is reached.
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None
                and state.active_rpc_count >= state.maximum_concurrent_rpcs
            )
            rpc_state, rpc_future = _handle_call(
                event,
                state.generic_handlers,
                state.interceptor_pipeline,
                state.thread_pool,
                concurrency_exceeded,
            )
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                state.active_rpc_count += 1
                # If the future is already done, the callback runs right away
                # on this thread. state.lock is an RLock, so
                # _on_call_completed can re-acquire it here.
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state)
                )
            # Keep a request_call outstanding while serving; once shutdown
            # has begun, try to finish stopping instead.
            if state.stage is _ServerStage.STARTED:
                _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        rpc_state, callbacks = event.tag(event)
        # Callbacks run outside the lock. A failing callback is logged so it
        # cannot break the serving loop.
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Exception calling callback!")
        # A non-None rpc_state is dropped from tracking here (presumably
        # because that RPC has finished -- confirm against the tag callables).
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
| |
| |
def _serve(state: _ServerState) -> None:
    """Poll the completion queue and dispatch events until serving stops.

    Each poll uses a deadline ``_DEALLOCATED_SERVER_CHECK_PERIOD_S`` seconds
    from now. This lets the loop notice ``state.server_deallocated`` (set by
    ``_Server.__del__``) even when no events arrive.
    """
    while True:
        deadline = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(deadline)
        if state.server_deallocated:
            _begin_shutdown_once(state)
        got_event = event.completion_type != cygrpc.CompletionType.queue_timeout
        if got_event and not _process_event_and_continue(state, event):
            return
        # We want to force the deletion of the previous event
        # ~before~ we poll again; if the event has a reference
        # to a shutdown Call object, this can induce spinlock.
        event = None
| |
| |
def _begin_shutdown_once(state: _ServerState) -> None:
    """Request core shutdown, but only when the server is in the STARTED stage.

    Calls made in any other stage are no-ops, so repeated calls are harmless.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        # The shutdown tag must be delivered before _stop_serving can succeed.
        state.due.add(_SHUTDOWN_TAG)
| |
| |
def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Begin stopping the server; return an event set once it has stopped.

    Args:
      state: The server state.
      grace: If None, cancel all calls immediately and block until serving
        has stopped. Otherwise, return immediately. A background thread
        waits up to ``grace`` seconds on the returned event and then cancels
        all remaining calls.

    Returns:
      An event that ``_stop_serving`` sets. It is already set if the server
      was STOPPED.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped

        _begin_shutdown_once(state)
        shutdown_event = threading.Event()
        state.shutdown_events.append(shutdown_event)

        if grace is not None:

            def cancel_all_calls_after_grace():
                shutdown_event.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            threading.Thread(target=cancel_all_calls_after_grace).start()
            return shutdown_event

        state.server.cancel_all_calls()
    # Block outside the lock so the serving thread can finish shutting down.
    shutdown_event.wait()
    return shutdown_event
| |
| |
def _start(state: _ServerState) -> None:
    """Start the core server, request the first call, and spawn the serving thread.

    Raises:
      ValueError: If the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError("Cannot start already-started server!")
        state.server.start()
        state.stage = _ServerStage.STARTED
        _request_call(state)
        # A daemon thread, so it does not keep the interpreter alive.
        threading.Thread(target=_serve, args=(state,), daemon=True).start()
| |
| |
| def _validate_generic_rpc_handlers( |
| generic_rpc_handlers: Iterable[grpc.GenericRpcHandler], |
| ) -> None: |
| for generic_rpc_handler in generic_rpc_handlers: |
| service_attribute = getattr(generic_rpc_handler, "service", None) |
| if service_attribute is None: |
| raise AttributeError( |
| '"{}" must conform to grpc.GenericRpcHandler type but does ' |
| 'not have "service" method!'.format(generic_rpc_handler) |
| ) |
| |
| |
def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression],
) -> Sequence[ChannelArgumentType]:
    """Return ``base_options`` as a tuple with the compression option(s) appended."""
    extra_options = _compression.create_channel_option(compression)
    merged = tuple(base_options)
    merged += extra_options
    return merged
| |
| |
class _Server(grpc.Server):
    """``grpc.Server`` implementation backed by a ``_ServerState``.

    The public methods delegate to this module's helpers, which operate on
    ``self._state``.
    """

    _state: _ServerState

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        thread_pool: futures.ThreadPoolExecutor,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptors: Sequence[grpc.ServerInterceptor],
        options: Sequence[ChannelArgumentType],
        maximum_concurrent_rpcs: Optional[int],
        compression: Optional[grpc.Compression],
        xds: bool,
    ):
        """Create the core server and its completion queue (not yet started).

        ``options`` are extended with the compression channel option before
        being handed to ``cygrpc.Server``. The remaining arguments are
        stored on the ``_ServerState``, with ``interceptors`` first wrapped
        in a service pipeline.
        """
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression), xds)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(
            completion_queue,
            server,
            generic_handlers,
            _interceptor.service_pipeline(interceptors),
            thread_pool,
            maximum_concurrent_rpcs,
        )

    def add_generic_rpc_handlers(
        self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
    ) -> None:
        """Validate and register additional generic RPC handlers.

        Raises:
          AttributeError: If any handler lacks a ``service`` attribute.
        """
        # Validation and registration each iterate the handlers. Materialize
        # them once so a one-shot iterable (e.g. a generator) is not
        # exhausted by validation, which would silently register nothing.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address: str) -> int:
        """Bind an insecure port; the result is checked by ``_common``."""
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address))
        )

    def add_secure_port(
        self, address: str, server_credentials: grpc.ServerCredentials
    ) -> int:
        """Bind a secure port; the result is checked by ``_common``."""
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(
                self._state, _common.encode(address), server_credentials
            ),
        )

    def start(self) -> None:
        """Start serving. Raises ValueError if the server was already started."""
        _start(self._state)

    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
        """Block until the termination event is set or ``timeout`` elapses.

        Returns whatever ``_common.wait`` returns (presumably whether the
        wait timed out -- confirm against ``_common.wait``).
        """
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(
            self._state.termination_event.wait,
            self._state.termination_event.is_set,
            timeout=timeout,
        )

    def stop(self, grace: Optional[float]) -> threading.Event:
        """Begin stopping; see ``_stop`` for the meaning of ``grace``."""
        return _stop(self._state, grace)

    def __del__(self):
        if hasattr(self, "_state"):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True
| |
| |
def create_server(
    thread_pool: futures.ThreadPoolExecutor,
    generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
    interceptors: Sequence[grpc.ServerInterceptor],
    options: Sequence[ChannelArgumentType],
    maximum_concurrent_rpcs: Optional[int],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> _Server:
    """Validate the generic handlers and construct an unstarted ``_Server``.

    Raises:
      AttributeError: If any handler lacks a ``service`` attribute.
    """
    # Validation iterates the handlers and _Server copies them again. Take a
    # single snapshot so a one-shot iterable passed by a caller is not
    # exhausted before the server registers it.
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(
        thread_pool,
        generic_rpc_handlers,
        interceptors,
        options,
        maximum_concurrent_rpcs,
        compression,
        xds,
    )