blob: b9fc3ccf0f600db2d864024cdf7e87e3effb2d44 [file] [log] [blame]
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

import importlib
import importlib.util
import logging
import os
import pkgutil
import re
import sys
import unittest

import coverage
# Module-level logger; used below to report modules skipped during discovery.
logger = logging.getLogger(name=__name__)

# A module contributes tests when its fully-qualified name ends in "_test".
TEST_MODULE_REGEX = r"^.*_test$"
# Determines the path of a given path relative to the first matching
# entry on sys.path. Useful for determining what a directory's module
# path will be.
def _relativize_to_sys_path(path):
for sys_path in sys.path:
if path.startswith(sys_path):
relative = path[len(sys_path) :]
if not relative:
return ""
if relative.startswith(os.path.sep):
relative = relative[len(os.path.sep) :]
if not relative.endswith(os.path.sep):
relative += os.path.sep
return relative
raise AssertionError("Failed to relativize {} to sys.path.".format(path))
def _relative_path_to_module_prefix(path):
return path.replace(os.path.sep, ".")
class Loader(object):
    """Test loader for setuptools test suite support.

    Attributes:
      suite (unittest.TestSuite): All tests collected by the loader.
      loader (unittest.TestLoader): Standard Python unittest loader to be run
        per module discovered.
      module_matcher (re.Pattern): A compiled regular expression matched
        against module names to determine whether or not the discovered
        module contributes to the test suite.
    """

    def __init__(self):
        self.suite = unittest.TestSuite()
        self.loader = unittest.TestLoader()
        self.module_matcher = re.compile(TEST_MODULE_REGEX)

    def loadTestsFromNames(self, names, module=None):
        """Function mirroring TestLoader::loadTestsFromNames, as expected by
        setuptools.setup argument `test_loader`.

        Args:
          names: Iterable of importable module or package names.
          module: Accepted for signature compatibility; unused.

        Returns:
          unittest.TestSuite: `self.suite`, holding the tests of every visited
            module whose name matches `self.module_matcher`.
        """
        # Start coverage before importing anything so that decorators and
        # definitions executed at import time are captured (else our coverage
        # measure unnecessarily suffers).
        coverage_context = coverage.Coverage(data_suffix=True)
        coverage_context.start()
        try:
            imported_modules = tuple(
                importlib.import_module(name) for name in names
            )
            for imported_module in imported_modules:
                self.visit_module(imported_module)
            for imported_module in imported_modules:
                # Only packages have __path__; plain modules have nothing to walk.
                try:
                    package_paths = imported_module.__path__
                except AttributeError:
                    continue
                self.walk_packages(package_paths)
        finally:
            # Stop and persist measurement even when an import fails, so the
            # coverage tracer is never left running.
            coverage_context.stop()
            coverage_context.save()
        return self.suite

    def walk_packages(self, package_paths):
        """Walks over the packages, dispatching `visit_module` calls.

        Args:
          package_paths (list): A list of paths over which to walk through
            modules along.
        """
        for path in package_paths:
            self._walk_package(path)

    def _walk_package(self, package_path):
        """Visits every module and subpackage found beneath one package path.

        Modules already present in sys.modules are reused; the rest are
        loaded through `_load_module`, and any it skips are not visited.

        Args:
          package_path (str): One directory from a package's `__path__`.
        """
        prefix = _relative_path_to_module_prefix(
            _relativize_to_sys_path(package_path)
        )
        for importer, module_name, _is_package in pkgutil.walk_packages(
            [package_path], prefix
        ):
            if module_name in sys.modules:
                module = sys.modules[module_name]
            else:
                module = self._load_module(importer, module_name)
                if module is None:
                    continue
            self.visit_module(module)

    def _load_module(self, importer, module_name):
        """Executes `module_name` through `importer`; returns None if skipped.

        Args:
          importer: The finder yielded by pkgutil.walk_packages for the module.
          module_name (str): Fully-qualified name of the module to load.

        Returns:
          The executed module object, or None when the importer finds no
          loadable spec or executing the module raises ModuleNotFoundError
          (both cases are logged at debug level).
        """
        spec = importer.find_spec(module_name)
        if spec is None or spec.loader is None:
            logger.debug("Skip loading %s", module_name)
            return None
        # NOTE(review): the module is executed without being inserted into
        # sys.modules, so a later regular import of the same name executes it
        # again as a separate module object -- confirm this is intended.
        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except ModuleNotFoundError:
            logger.debug("Skip loading %s", module_name)
            return None
        return module

    def visit_module(self, module):
        """Visits the module, adding discovered tests to the test suite.

        Args:
          module (module): Module to match against self.module_matcher; if
            matched it has its tests loaded via self.loader into self.suite.
        """
        if self.module_matcher.match(module.__name__):
            module_suite = self.loader.loadTestsFromModule(module)
            self.suite.addTest(module_suite)
def iterate_suite_cases(suite):
    """Yields every unittest.TestCase inside `suite`, flattening nested suites.

    Nested TestSuites are descended into recursively, in place, so cases come
    out in the suite's own iteration order.

    Args:
      suite (unittest.TestSuite): Suite to walk.

    Yields:
      unittest.TestCase: each test case contained (directly or indirectly)
        in `suite`.

    Raises:
      ValueError: when an element is neither a TestSuite nor a TestCase.
    """
    for entry in suite:
        if isinstance(entry, unittest.TestSuite):
            yield from iterate_suite_cases(entry)
            continue
        if not isinstance(entry, unittest.TestCase):
            raise ValueError(
                "unexpected suite item of type {}".format(type(entry))
            )
        yield entry