fix: resolve issue handling protobuf responses in rest streaming (#604)
* fix: resolve issue handling protobuf responses in rest streaming
* raise ValueError if response_message_cls is not a subclass of proto.Message or google.protobuf.message.Message
* remove response_type from pytest.mark.parametrize
* 🦉 Updates from OwlBot post-processor
See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* add test for ValueError in response_iterator._grab()
---------
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py
index f91381c..3f5b6b0 100644
--- a/google/api_core/rest_streaming.py
+++ b/google/api_core/rest_streaming.py
@@ -16,9 +16,12 @@
from collections import deque
import string
-from typing import Deque
+from typing import Deque, Union
+import proto
import requests
+import google.protobuf.message
+from google.protobuf.json_format import Parse
class ResponseIterator:
@@ -26,11 +29,18 @@
Args:
response (requests.Response): An API response object.
- response_message_cls (Callable[proto.Message]): A proto
+ response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.
+
+ Raises:
+ ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""
- def __init__(self, response: requests.Response, response_message_cls):
+ def __init__(
+ self,
+ response: requests.Response,
+ response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+ ):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
@@ -107,7 +117,14 @@
def _grab(self):
# Add extra quotes to make json.loads happy.
- return self._response_message_cls.from_json(self._ready_objs.popleft())
+ if issubclass(self._response_message_cls, proto.Message):
+ return self._response_message_cls.from_json(self._ready_objs.popleft())
+ elif issubclass(self._response_message_cls, google.protobuf.message.Message):
+ return Parse(self._ready_objs.popleft(), self._response_message_cls())
+ else:
+ raise ValueError(
+ "Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
+ )
def __iter__(self):
return self
diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py
index a44c83c..b532eb1 100644
--- a/tests/unit/test_rest_streaming.py
+++ b/tests/unit/test_rest_streaming.py
@@ -24,8 +24,11 @@
import requests
from google.api_core import rest_streaming
+from google.api import http_pb2
+from google.api import httpbody_pb2
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
+from google.protobuf.json_format import MessageToJson
__protobuf__ = proto.module(package=__name__)
@@ -98,7 +101,10 @@
# json.dumps returns a string surrounded with quotes that need to be stripped
# in order to be an actual JSON.
json_responses = [
- self._response_message_cls.to_json(r).strip('"') for r in responses
+ self._response_message_cls.to_json(r).strip('"')
+ if issubclass(self._response_message_cls, proto.Message)
+ else MessageToJson(r).strip('"')
+ for r in responses
]
logging.info(f"Sending JSON stream: {json_responses}")
ret_val = "[{}]".format(",".join(json_responses))
@@ -114,103 +120,220 @@
)
[email protected]("random_split", [False])
-def test_next_simple(random_split):
- responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
[email protected](
+ "random_split,resp_message_is_proto_plus",
+ [(False, True), (False, False)],
+)
+def test_next_simple(random_split, resp_message_is_proto_plus):
+ if resp_message_is_proto_plus:
+ response_type = EchoResponse
+ responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
+ else:
+ response_type = httpbody_pb2.HttpBody
+ responses = [
+ httpbody_pb2.HttpBody(content_type="hello world"),
+ httpbody_pb2.HttpBody(content_type="yes"),
+ ]
+
resp = ResponseMock(
- responses=responses, random_split=random_split, response_cls=EchoResponse
+ responses=responses, random_split=random_split, response_cls=response_type
)
- itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses
[email protected]("random_split", [True, False])
-def test_next_nested(random_split):
- responses = [
- Song(title="some song", composer=Composer(given_name="some name")),
- Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
- ]
[email protected](
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+def test_next_nested(random_split, resp_message_is_proto_plus):
+ if resp_message_is_proto_plus:
+ response_type = Song
+ responses = [
+ Song(title="some song", composer=Composer(given_name="some name")),
+ Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
+ ]
+ else:
+ # Although `http_pb2.HttpRule`` is used in the response, any response message
+ # can be used which meets this criteria for the test of having a nested field.
+ response_type = http_pb2.HttpRule
+ responses = [
+ http_pb2.HttpRule(
+ selector="some selector",
+ custom=http_pb2.CustomHttpPattern(kind="some kind"),
+ ),
+ http_pb2.HttpRule(
+ selector="another selector",
+ custom=http_pb2.CustomHttpPattern(path="some path"),
+ ),
+ ]
resp = ResponseMock(
- responses=responses, random_split=random_split, response_cls=Song
+ responses=responses, random_split=random_split, response_cls=response_type
)
- itr = rest_streaming.ResponseIterator(resp, Song)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses
[email protected]("random_split", [True, False])
-def test_next_stress(random_split):
[email protected](
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+def test_next_stress(random_split, resp_message_is_proto_plus):
n = 50
- responses = [
- Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
- for i in range(n)
- ]
+ if resp_message_is_proto_plus:
+ response_type = Song
+ responses = [
+ Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
+ for i in range(n)
+ ]
+ else:
+ response_type = http_pb2.HttpRule
+ responses = [
+ http_pb2.HttpRule(
+ selector="selector_%d" % i,
+ custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
+ )
+ for i in range(n)
+ ]
resp = ResponseMock(
- responses=responses, random_split=random_split, response_cls=Song
+ responses=responses, random_split=random_split, response_cls=response_type
)
- itr = rest_streaming.ResponseIterator(resp, Song)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses
[email protected]("random_split", [True, False])
-def test_next_escaped_characters_in_string(random_split):
- composer_with_relateds = Composer()
- relateds = ["Artist A", "Artist B"]
- composer_with_relateds.relateds = relateds
[email protected](
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus):
+ if resp_message_is_proto_plus:
+ response_type = Song
+ composer_with_relateds = Composer()
+ relateds = ["Artist A", "Artist B"]
+ composer_with_relateds.relateds = relateds
- responses = [
- Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
- Song(
- title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
- ),
- Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
- ]
+ responses = [
+ Song(
+ title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
+ ),
+ Song(
+ title='{"this is weird": "totally"}',
+ composer=Composer(given_name="\\{}\\"),
+ ),
+ Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
+ ]
+ else:
+ response_type = http_pb2.Http
+ responses = [
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='ti"tle\nfoo\tbar{}',
+ custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
+ )
+ ]
+ ),
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='{"this is weird": "totally"}',
+ custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+ )
+ ]
+ ),
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='\\{"key": ["value",]}\\',
+ custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+ )
+ ]
+ ),
+ ]
resp = ResponseMock(
- responses=responses, random_split=random_split, response_cls=Song
+ responses=responses, random_split=random_split, response_cls=response_type
)
- itr = rest_streaming.ResponseIterator(resp, Song)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses
-def test_next_not_array():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_next_not_array(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
) as mock_method:
-
- resp = ResponseMock(responses=[], response_cls=EchoResponse)
- itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()
-def test_cancel():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_cancel(response_type):
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
- resp = ResponseMock(responses=[], response_cls=EchoResponse)
- itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
itr.cancel()
mock_method.assert_called_once()
-def test_check_buffer():
[email protected](
+ "response_type,return_value",
+ [
+ (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
+ (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
+ ],
+)
+def test_check_buffer(response_type, return_value):
with patch.object(
ResponseMock,
"_parse_responses",
- return_value=bytes('[{"content": "hello"}, {', "utf-8"),
+ return_value=return_value,
):
- resp = ResponseMock(responses=[], response_cls=EchoResponse)
- itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
next(itr)
-def test_next_html():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_next_html(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
) as mock_method:
-
- resp = ResponseMock(responses=[], response_cls=EchoResponse)
- itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()
+
+
+def test_invalid_response_class():
+ class SomeClass:
+ pass
+
+ resp = ResponseMock(responses=[], response_cls=SomeClass)
+ response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
+ with pytest.raises(
+ ValueError,
+ match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
+ ):
+ response_iterator._grab()