| #!/usr/bin/env python3 |
| # Copyright 2015 gRPC authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Starts a local DNS server for use in tests""" |
| |
import argparse
import os
import platform
import signal
import sys
import threading
import time

import twisted
import twisted.internet
import twisted.internet.defer
import twisted.internet.error
import twisted.internet.protocol
import twisted.internet.reactor
import twisted.internet.threads
import twisted.names
from twisted.names import authority
from twisted.names import client
from twisted.names import common
from twisted.names import dns
from twisted.names import server
import twisted.names.client
import twisted.names.dns
import twisted.names.server
import yaml
| |
| _SERVER_HEALTH_CHECK_RECORD_NAME = ( # missing end '.' for twisted syntax |
| "health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp" |
| ) |
| _SERVER_HEALTH_CHECK_RECORD_DATA = "123.123.123.123" |
| |
| |
class NoFileAuthority(authority.FileAuthority):
    """A FileAuthority whose zone data is handed in directly instead of read from a file.

    Args:
      soa: the authority's SOA entry; start_local_dns_server passes a
        (zone name bytes, dns.Record_SOA) pair.
      records: mapping of record name to a list of record objects, stored
        as-is on the instance.
    """

    def __init__(self, soa, records):
        # Intentionally skip FileAuthority.__init__ (which takes a filename,
        # presumably to load a zone file) and only initialize the
        # ResolverBase part; the zone data is assigned directly below.
        common.ResolverBase.__init__(self)
        self.records = records
        self.soa = soa
| |
| |
def _load_dns_records(args):
    """Build the in-memory record table served by the local DNS server.

    Reads the YAML file at ``args.records_config_path`` and converts every
    A, AAAA, SRV and TXT record of every group under
    ``resolver_component_tests`` into a twisted record object (other record
    types are ignored). Then adds the optional ``--add_a_record`` entry and
    the server health-check A record.

    Args:
      args: parsed command-line namespace; ``records_config_path`` and
        ``add_a_record`` are read.

    Returns:
      A ``(common_zone_name, all_records)`` tuple: the zone name string from
      the config, and a dict mapping ASCII-encoded record names (without the
      trailing '.') to lists of twisted record objects.

    Raises:
      ValueError: if a configured record name does not end with '.', or an
        SRV record's data does not have exactly four fields.
    """
    all_records = {}

    def _push_record(name, r):
        # Record names are stored as ASCII bytes.
        name = name.encode("ascii")
        print("pushing record: |%s|" % name)
        all_records.setdefault(name, []).append(r)

    def _maybe_split_up_txt_data(name, txt_data, r_ttl):
        # A DNS <character-string> holds at most 255 octets (RFC 1035), so
        # long TXT data is split into consecutive 255-byte strings inside a
        # single TXT record.
        txt_data = txt_data.encode("ascii")
        txt_data_list = [
            txt_data[start : start + 255]
            for start in range(0, len(txt_data), 255)
        ]
        _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))

    with open(args.records_config_path) as config:
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config["resolver_tests_common_zone_name"]
    for group in test_records_config["resolver_component_tests"]:
        for name, records in group["records"].items():
            for record in records:
                r_type = record["type"]
                r_data = record["data"]
                r_ttl = int(record["TTL"])
                record_full_name = "%s.%s" % (name, common_zone_name)
                # Config names are fully qualified (ending in '.'); the
                # trailing root dot is stripped before handing to twisted.
                # Explicit check rather than `assert`, which -O strips.
                if not record_full_name.endswith("."):
                    raise ValueError(
                        "record name %r must end with '.'" % record_full_name
                    )
                record_full_name = record_full_name[:-1]
                if r_type == "A":
                    _push_record(
                        record_full_name, dns.Record_A(r_data, ttl=r_ttl)
                    )
                elif r_type == "AAAA":
                    _push_record(
                        record_full_name, dns.Record_AAAA(r_data, ttl=r_ttl)
                    )
                elif r_type == "SRV":
                    # SRV data is "<priority> <weight> <port> <target>".
                    p, w, port, target = r_data.split()
                    target_full_name = (
                        "%s.%s" % (target, common_zone_name)
                    ).encode("ascii")
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(
                            int(p), int(w), int(port), target_full_name,
                            ttl=r_ttl,
                        ),
                    )
                elif r_type == "TXT":
                    _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
    # Add an optional IPv4 record if one was specified on the command line.
    if args.add_a_record:
        extra_host, extra_host_ipv4 = args.add_a_record.split(":")
        _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
    # Server health check record.
    _push_record(
        _SERVER_HEALTH_CHECK_RECORD_NAME,
        dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0),
    )
    return common_zone_name, all_records


def start_local_dns_server(args):
    """Load the configured records and serve them over TCP and UDP.

    Blocks in ``twisted.internet.reactor.run()`` until the reactor stops.

    Args:
      args: parsed command-line namespace; ``port`` is read here, plus the
        fields read by ``_load_dns_records``.
    """
    common_zone_name, all_records = _load_dns_records(args)
    zone_name = common_zone_name.encode("ascii")
    soa_record = dns.Record_SOA(mname=zone_name)
    test_domain_com = NoFileAuthority(
        soa=(zone_name, soa_record),
        records=all_records,
    )
    # Not called `server`, which would shadow the imported
    # twisted.names.server module.
    dns_server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2
    )
    dns_server_factory.noisy = 2
    # NOTE(review): no interface is passed to listenTCP/listenUDP, so twisted
    # binds all interfaces, although the log line below says 127.0.0.1.
    twisted.internet.reactor.listenTCP(args.port, dns_server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(dns_server_factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print("starting local dns server on 127.0.0.1:%s" % args.port)
    print("starting twisted.internet.reactor")
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()
| |
| |
def _quit_on_signal(signum, _frame):
    """Signal handler: stop the twisted reactor and exit with status 0.

    Args:
      signum: number of the received signal.
      _frame: interrupted stack frame (unused).
    """
    print("Received SIGNAL %d. Quitting with exit code 0" % signum)
    try:
        twisted.internet.reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # The signal arrived before reactor.run() started (e.g. while the
        # records config was still loading) or after it already stopped.
        # Nothing to stop; still exit 0 as announced above instead of dying
        # with a traceback.
        pass
    sys.stdout.flush()
    sys.exit(0)
| |
| |
def flush_stdout_loop():
    """Flush stdout once per second, then SIGTERM this process after 600 rounds.

    The self-termination acts as a safety net: tests that use this server are
    short-lived, so a server still running after roughly ten minutes is
    treated as left behind and shut down.
    """
    poll_interval_secs = 1
    # Prevent zombies. Tests that use this server are short-lived.
    total_polls = 60 * 10
    for _ in range(total_polls):
        sys.stdout.flush()
        time.sleep(poll_interval_secs)
    print("Process timeout reached, or cancelled. Exitting 0.")
    os.kill(os.getpid(), signal.SIGTERM)
| |
| |
def main():
    """Parse flags, install signal handlers and run the local DNS server.

    Blocks until the server's reactor stops.
    """
    argp = argparse.ArgumentParser(
        description="Local DNS Server for resolver tests"
    )
    # --port and --records_config_path are marked required: without them the
    # server cannot start (listenTCP(None) / open(None) would fail), so
    # argparse now reports the problem instead of a later traceback.
    argp.add_argument(
        "-p",
        "--port",
        required=True,
        type=int,
        help="Port for DNS server to listen on for TCP and UDP.",
    )
    argp.add_argument(
        "-r",
        "--records_config_path",
        required=True,
        type=str,
        help="Path to the resolver_test_record_groups.yaml file.",
    )
    argp.add_argument(
        "--add_a_record",
        default=None,
        type=str,
        help=(
            "Add an A record via the command line. Useful for when we "
            "need to serve a one-off A record that is under a "
            "different domain then the rest the records configured in "
            "--records_config_path (which all need to be under the "
            "same domain). Format: <name>:<ipv4 address>"
        ),
    )
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    # Daemon thread so it never keeps the process alive on its own. Passed
    # via the constructor: Thread.setDaemon() is deprecated since 3.10.
    output_flush_thread = threading.Thread(
        target=flush_stdout_loop, daemon=True
    )
    output_flush_thread.start()
    start_local_dns_server(args)
| |
| |
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()