fix: resolve issue handling protobuf responses in rest streaming (#604)

* fix: resolve issue handling protobuf responses in rest streaming

* raise ValueError if response_message_cls is not a subclass of proto.Message or google.protobuf.message.Message

* remove response_type from pytest.mark.parametrize

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* add test for ValueError in response_iterator._grab()

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py
index f91381c..3f5b6b0 100644
--- a/google/api_core/rest_streaming.py
+++ b/google/api_core/rest_streaming.py
@@ -16,9 +16,12 @@
 
 from collections import deque
 import string
-from typing import Deque
+from typing import Deque, Union
 
+import proto
 import requests
+import google.protobuf.message
+from google.protobuf.json_format import Parse
 
 
 class ResponseIterator:
@@ -26,11 +29,18 @@
 
     Args:
         response (requests.Response): An API response object.
-        response_message_cls (Callable[proto.Message]): A proto
+        response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
         class expected to be returned from an API.
+
+    Raises:
+        ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
     """
 
-    def __init__(self, response: requests.Response, response_message_cls):
+    def __init__(
+        self,
+        response: requests.Response,
+        response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+    ):
         self._response = response
         self._response_message_cls = response_message_cls
         # Inner iterator over HTTP response's content.
@@ -107,7 +117,14 @@
 
     def _grab(self):
         # Add extra quotes to make json.loads happy.
-        return self._response_message_cls.from_json(self._ready_objs.popleft())
+        if issubclass(self._response_message_cls, proto.Message):
+            return self._response_message_cls.from_json(self._ready_objs.popleft())
+        elif issubclass(self._response_message_cls, google.protobuf.message.Message):
+            return Parse(self._ready_objs.popleft(), self._response_message_cls())
+        else:
+            raise ValueError(
+                "Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
+            )
 
     def __iter__(self):
         return self
diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py
index a44c83c..b532eb1 100644
--- a/tests/unit/test_rest_streaming.py
+++ b/tests/unit/test_rest_streaming.py
@@ -24,8 +24,11 @@
 import requests
 
 from google.api_core import rest_streaming
+from google.api import http_pb2
+from google.api import httpbody_pb2
 from google.protobuf import duration_pb2
 from google.protobuf import timestamp_pb2
+from google.protobuf.json_format import MessageToJson
 
 
 __protobuf__ = proto.module(package=__name__)
@@ -98,7 +101,10 @@
         # json.dumps returns a string surrounded with quotes that need to be stripped
         # in order to be an actual JSON.
         json_responses = [
-            self._response_message_cls.to_json(r).strip('"') for r in responses
+            self._response_message_cls.to_json(r).strip('"')
+            if issubclass(self._response_message_cls, proto.Message)
+            else MessageToJson(r).strip('"')
+            for r in responses
         ]
         logging.info(f"Sending JSON stream: {json_responses}")
         ret_val = "[{}]".format(",".join(json_responses))
@@ -114,103 +120,220 @@
         )
 
 
[email protected]("random_split", [False])
-def test_next_simple(random_split):
-    responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
[email protected](
+    "random_split,resp_message_is_proto_plus",
+    [(False, True), (False, False)],
+)
+def test_next_simple(random_split, resp_message_is_proto_plus):
+    if resp_message_is_proto_plus:
+        response_type = EchoResponse
+        responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
+    else:
+        response_type = httpbody_pb2.HttpBody
+        responses = [
+            httpbody_pb2.HttpBody(content_type="hello world"),
+            httpbody_pb2.HttpBody(content_type="yes"),
+        ]
+
     resp = ResponseMock(
-        responses=responses, random_split=random_split, response_cls=EchoResponse
+        responses=responses, random_split=random_split, response_cls=response_type
     )
-    itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+    itr = rest_streaming.ResponseIterator(resp, response_type)
     assert list(itr) == responses
 
 
[email protected]("random_split", [True, False])
-def test_next_nested(random_split):
-    responses = [
-        Song(title="some song", composer=Composer(given_name="some name")),
-        Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
-    ]
[email protected](
+    "random_split,resp_message_is_proto_plus",
+    [
+        (True, True),
+        (False, True),
+        (True, False),
+        (False, False),
+    ],
+)
+def test_next_nested(random_split, resp_message_is_proto_plus):
+    if resp_message_is_proto_plus:
+        response_type = Song
+        responses = [
+            Song(title="some song", composer=Composer(given_name="some name")),
+            Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
+        ]
+    else:
+        # Although `http_pb2.HttpRule`` is used in the response, any response message
+        # can be used which meets this criteria for the test of having a nested field.
+        response_type = http_pb2.HttpRule
+        responses = [
+            http_pb2.HttpRule(
+                selector="some selector",
+                custom=http_pb2.CustomHttpPattern(kind="some kind"),
+            ),
+            http_pb2.HttpRule(
+                selector="another selector",
+                custom=http_pb2.CustomHttpPattern(path="some path"),
+            ),
+        ]
     resp = ResponseMock(
-        responses=responses, random_split=random_split, response_cls=Song
+        responses=responses, random_split=random_split, response_cls=response_type
     )
-    itr = rest_streaming.ResponseIterator(resp, Song)
+    itr = rest_streaming.ResponseIterator(resp, response_type)
     assert list(itr) == responses
 
 
[email protected]("random_split", [True, False])
-def test_next_stress(random_split):
[email protected](
+    "random_split,resp_message_is_proto_plus",
+    [
+        (True, True),
+        (False, True),
+        (True, False),
+        (False, False),
+    ],
+)
+def test_next_stress(random_split, resp_message_is_proto_plus):
     n = 50
-    responses = [
-        Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
-        for i in range(n)
-    ]
+    if resp_message_is_proto_plus:
+        response_type = Song
+        responses = [
+            Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
+            for i in range(n)
+        ]
+    else:
+        response_type = http_pb2.HttpRule
+        responses = [
+            http_pb2.HttpRule(
+                selector="selector_%d" % i,
+                custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
+            )
+            for i in range(n)
+        ]
     resp = ResponseMock(
-        responses=responses, random_split=random_split, response_cls=Song
+        responses=responses, random_split=random_split, response_cls=response_type
     )
-    itr = rest_streaming.ResponseIterator(resp, Song)
+    itr = rest_streaming.ResponseIterator(resp, response_type)
     assert list(itr) == responses
 
 
[email protected]("random_split", [True, False])
-def test_next_escaped_characters_in_string(random_split):
-    composer_with_relateds = Composer()
-    relateds = ["Artist A", "Artist B"]
-    composer_with_relateds.relateds = relateds
[email protected](
+    "random_split,resp_message_is_proto_plus",
+    [
+        (True, True),
+        (False, True),
+        (True, False),
+        (False, False),
+    ],
+)
+def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus):
+    if resp_message_is_proto_plus:
+        response_type = Song
+        composer_with_relateds = Composer()
+        relateds = ["Artist A", "Artist B"]
+        composer_with_relateds.relateds = relateds
 
-    responses = [
-        Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
-        Song(
-            title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
-        ),
-        Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
-    ]
+        responses = [
+            Song(
+                title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
+            ),
+            Song(
+                title='{"this is weird": "totally"}',
+                composer=Composer(given_name="\\{}\\"),
+            ),
+            Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
+        ]
+    else:
+        response_type = http_pb2.Http
+        responses = [
+            http_pb2.Http(
+                rules=[
+                    http_pb2.HttpRule(
+                        selector='ti"tle\nfoo\tbar{}',
+                        custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
+                    )
+                ]
+            ),
+            http_pb2.Http(
+                rules=[
+                    http_pb2.HttpRule(
+                        selector='{"this is weird": "totally"}',
+                        custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+                    )
+                ]
+            ),
+            http_pb2.Http(
+                rules=[
+                    http_pb2.HttpRule(
+                        selector='\\{"key": ["value",]}\\',
+                        custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+                    )
+                ]
+            ),
+        ]
     resp = ResponseMock(
-        responses=responses, random_split=random_split, response_cls=Song
+        responses=responses, random_split=random_split, response_cls=response_type
     )
-    itr = rest_streaming.ResponseIterator(resp, Song)
+    itr = rest_streaming.ResponseIterator(resp, response_type)
     assert list(itr) == responses
 
 
-def test_next_not_array():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_next_not_array(response_type):
     with patch.object(
         ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
     ) as mock_method:
-
-        resp = ResponseMock(responses=[], response_cls=EchoResponse)
-        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
         with pytest.raises(ValueError):
             next(itr)
         mock_method.assert_called_once()
 
 
-def test_cancel():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_cancel(response_type):
     with patch.object(ResponseMock, "close", return_value=None) as mock_method:
-        resp = ResponseMock(responses=[], response_cls=EchoResponse)
-        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
         itr.cancel()
         mock_method.assert_called_once()
 
 
-def test_check_buffer():
[email protected](
+    "response_type,return_value",
+    [
+        (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
+        (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
+    ],
+)
+def test_check_buffer(response_type, return_value):
     with patch.object(
         ResponseMock,
         "_parse_responses",
-        return_value=bytes('[{"content": "hello"}, {', "utf-8"),
+        return_value=return_value,
     ):
-        resp = ResponseMock(responses=[], response_cls=EchoResponse)
-        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
         with pytest.raises(ValueError):
             next(itr)
             next(itr)
 
 
-def test_next_html():
[email protected]("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_next_html(response_type):
     with patch.object(
         ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
     ) as mock_method:
-
-        resp = ResponseMock(responses=[], response_cls=EchoResponse)
-        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
         with pytest.raises(ValueError):
             next(itr)
         mock_method.assert_called_once()
+
+
+def test_invalid_response_class():
+    class SomeClass:
+        pass
+
+    resp = ResponseMock(responses=[], response_cls=SomeClass)
+    response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
+    with pytest.raises(
+        ValueError,
+        match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
+    ):
+        response_iterator._grab()