| # Copyright 2016 gRPC authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Entry point for running stress tests.""" |
| |
| from concurrent import futures |
| import queue |
| import threading |
| |
| from absl import app |
| from absl.flags import argparse_flags |
| import grpc |
| |
| from src.proto.grpc.testing import metrics_pb2_grpc |
| from src.proto.grpc.testing import test_pb2_grpc |
| from tests.interop import methods |
| from tests.interop import resources |
| from tests.qps import histogram |
| from tests.stress import metrics_server |
| from tests.stress import test_runner |
| |
| |
def _str_to_bool(value):
    """Convert a textual boolean flag value (e.g. "true"/"false") to a bool.

    ``bool`` cannot be used directly as an argparse ``type``: every non-empty
    string, including "false", is truthy, so ``--use_tls=false`` would
    silently enable TLS. An empty string still maps to False, as it did with
    ``bool``.

    Raises:
      ValueError: for an unrecognised spelling; argparse turns this into an
        "invalid ... value" usage error.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise ValueError("invalid boolean value: {!r}".format(value))


def _args(argv):
    """Parse the stress-test client's command-line flags.

    Args:
      argv: the full argument vector; ``argv[0]`` (the program name) is
        skipped.

    Returns:
      The parsed namespace, with one attribute per flag defined below.
    """
    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        "--server_addresses",
        help="comma separated list of hostname:port to run servers on",
        default="localhost:8080",
        type=str,
    )
    parser.add_argument(
        "--test_cases",
        help="comma separated list of testcase:weighting of tests to run",
        default="large_unary:100",
        type=str,
    )
    parser.add_argument(
        "--test_duration_secs",
        help="number of seconds to run the stress test",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--num_channels_per_server",
        help="number of channels per server",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--num_stubs_per_channel",
        help="number of stubs to create per channel",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--metrics_port",
        help="the port to listen for metrics requests on",
        default=8081,
        type=int,
    )
    # Boolean flags take an explicit value (e.g. --use_tls=true); see
    # _str_to_bool for why plain ``bool`` is not used as the converter.
    parser.add_argument(
        "--use_test_ca",
        help="Whether to use our fake CA. Requires --use_tls=true",
        default=False,
        type=_str_to_bool,
    )
    parser.add_argument(
        "--use_tls",
        help="Whether to use TLS",
        default=False,
        type=_str_to_bool,
    )
    parser.add_argument(
        "--server_host_override",
        help="the server host to which to claim to connect",
        type=str,
    )
    return parser.parse_args(argv[1:])
| |
| |
def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value equals the argument.

    Args:
      test_case_arg: the test case name as written on the command line.

    Raises:
      ValueError: if no ``methods.TestCase`` member has that value.
    """
    not_found = object()
    matching = (
        candidate
        for candidate in methods.TestCase
        if test_case_arg == candidate.value
    )
    found = next(matching, not_found)
    if found is not_found:
        raise ValueError("No test case {}!".format(test_case_arg))
    return found
| |
| |
def _parse_weighted_test_cases(test_case_args):
    """Parse a ``testcase:weighting[,testcase:weighting...]`` specification.

    Whitespace around entries and test case names is ignored, as are empty
    entries (e.g. from a trailing comma).

    Args:
      test_case_args: comma separated ``testcase:weighting`` pairs, e.g.
        ``"large_unary:100,empty_unary:50"``.

    Returns:
      A dict mapping each test case (as resolved by ``_test_case_from_arg``)
      to its integer weight. If a test case is listed more than once, the
      last weight wins.

    Raises:
      ValueError: if an entry lacks the ``:weighting`` part, a weight is not
        an integer, a name is not a known test case, or no test case is
        given at all.
    """
    weighted_test_cases = {}
    for test_case_arg in test_case_args.split(","):
        test_case_arg = test_case_arg.strip()
        if not test_case_arg:
            continue
        name, separator, weight = test_case_arg.partition(":")
        if not separator:
            raise ValueError(
                "Test case {!r} must have the form testcase:weighting".format(
                    test_case_arg
                )
            )
        test_case = _test_case_from_arg(name.strip())
        weighted_test_cases[test_case] = int(weight)
    if not weighted_test_cases:
        raise ValueError(
            "No test cases given in {!r}".format(test_case_args)
        )
    return weighted_test_cases
| |
| |
def _get_channel(target, args):
    """Create a channel to ``target`` and block until it is ready.

    Args:
      target: the ``hostname:port`` of a test server.
      args: parsed flags; reads ``use_tls``, ``use_test_ca`` and
        ``server_host_override``.

    Returns:
      A ready channel: secure when ``args.use_tls`` is set, insecure
      otherwise.
    """
    if args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates
        )
        # --server_host_override has no default. gRPC expects string/int
        # channel-argument values, so only pass the override when one was
        # actually supplied instead of sending None.
        options = ()
        if args.server_host_override:
            options = (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
        channel = grpc.secure_channel(
            target, channel_credentials, options=options
        )
    else:
        channel = grpc.insecure_channel(target)

    # Wait (with no timeout) for the channel to be ready before any runner
    # starts sending messages on it.
    grpc.channel_ready_future(channel).result()
    return channel
| |
| |
def run_test(args):
    """Run the stress test described by the parsed flags in ``args``.

    Starts a metrics server on ``args.metrics_port`` and creates
    ``num_channels_per_server`` channels per server address, with
    ``num_stubs_per_channel`` test runners per channel. It then waits until
    ``test_duration_secs`` elapses (forever when negative) or a runner
    reports an exception, which is re-raised here.

    Args:
      args: the namespace produced by ``_args``.

    Raises:
      The first exception a runner put on the shared exception queue.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(",")
    # Runners propagate client exceptions through this queue; the main
    # thread re-raises the first one it receives.
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    channels = []
    runners = []
    started_runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server
    )
    server.add_insecure_port("[::]:{}".format(args.metrics_port))
    server.start()

    # Everything after server.start() is covered by the finally below. Even
    # if connecting to a test server or starting a runner fails, the started
    # runners are stopped, the channels closed and the metrics server shut
    # down.
    try:
        for test_server_target in test_server_targets:
            for _ in range(args.num_channels_per_server):
                channel = _get_channel(test_server_target, args)
                channels.append(channel)
                for _ in range(args.num_stubs_per_channel):
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    runners.append(
                        test_runner.TestRunner(
                            stub, test_cases, hist, exception_queue, stop_event
                        )
                    )

        for runner in runners:
            runner.start()
            # Track only runners that actually started, so the cleanup never
            # joins an unstarted one (threading.Thread.join raises in that
            # case; assumes TestRunner is Thread-like).
            started_runners.append(runner)

        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            timeout_secs = None  # no duration: run until a runner fails
        try:
            raise exception_queue.get(block=True, timeout=timeout_secs)
        except queue.Empty:
            # No exceptions thrown, success
            pass
    finally:
        stop_event.set()
        for runner in started_runners:
            runner.join()
        # Close channels only after every runner using them has finished.
        for channel in channels:
            channel.close()
        server.stop(None)
| |
| |
if __name__ == "__main__":
    # absl parses the command line with _args and passes the resulting
    # namespace to run_test as its single argument.
    app.run(main=run_test, flags_parser=_args)