| # Lint as: python2, python3 | 
 | # Copyright 2008 Google Inc, Martin J. Bligh <[email protected]>, | 
 | #                Benjamin Poirier, Ryan Stutsman | 
 | # Released under the GPL v2 | 
 | """ | 
 | Miscellaneous small functions. | 
 |  | 
 | DO NOT import this file directly - it is mixed in by server/utils.py, | 
 | import that instead | 
 | """ | 
 |  | 
 | from __future__ import absolute_import | 
 | from __future__ import division | 
 | from __future__ import print_function | 
 |  | 
 | import atexit, os, re, shutil, textwrap, sys, tempfile, types | 
 | import six | 
 |  | 
 | from autotest_lib.client.common_lib import barrier, utils | 
 | from autotest_lib.server import subcommand | 
 |  | 
 |  | 
 | # A dictionary of pid and a list of tmpdirs for that pid | 
 | __tmp_dirs = {} | 
 |  | 
 |  | 
 | def scp_remote_escape(filename): | 
 |     """ | 
 |     Escape special characters from a filename so that it can be passed | 
 |     to scp (within double quotes) as a remote file. | 
 |  | 
 |     Bis-quoting has to be used with scp for remote files, "bis-quoting" | 
 |     as in quoting x 2 | 
 |     scp does not support a newline in the filename | 
 |  | 
 |     Args: | 
 |             filename: the filename string to escape. | 
 |  | 
 |     Returns: | 
 |             The escaped filename string. The required englobing double | 
 |             quotes are NOT added and so should be added at some point by | 
 |             the caller. | 
 |     """ | 
 |     escape_chars= r' !"$&' "'" r'()*,:;<=>?[\]^`{|}' | 
 |  | 
 |     new_name= [] | 
 |     for char in filename: | 
 |         if char in escape_chars: | 
 |             new_name.append("\\%s" % (char,)) | 
 |         else: | 
 |             new_name.append(char) | 
 |  | 
 |     return utils.sh_escape("".join(new_name)) | 
 |  | 
 |  | 
 | def get(location, local_copy = False): | 
 |     """Get a file or directory to a local temporary directory. | 
 |  | 
 |     Args: | 
 |             location: the source of the material to get. This source may | 
 |                     be one of: | 
 |                     * a local file or directory | 
 |                     * a URL (http or ftp) | 
 |                     * a python file-like object | 
 |  | 
 |     Returns: | 
 |             The location of the file or directory where the requested | 
 |             content was saved. This will be contained in a temporary | 
 |             directory on the local host. If the material to get was a | 
 |             directory, the location will contain a trailing '/' | 
 |     """ | 
 |     tmpdir = get_tmp_dir() | 
 |  | 
 |     # location is a file-like object | 
 |     if hasattr(location, "read"): | 
 |         tmpfile = os.path.join(tmpdir, "file") | 
 |         tmpfileobj = open(tmpfile, 'w') | 
 |         shutil.copyfileobj(location, tmpfileobj) | 
 |         tmpfileobj.close() | 
 |         return tmpfile | 
 |  | 
 |     if isinstance(location, six.string_types): | 
 |         # location is a URL | 
 |         if location.startswith('http') or location.startswith('ftp'): | 
 |             tmpfile = os.path.join(tmpdir, os.path.basename(location)) | 
 |             utils.urlretrieve(location, tmpfile) | 
 |             return tmpfile | 
 |         # location is a local path | 
 |         elif os.path.exists(os.path.abspath(location)): | 
 |             if not local_copy: | 
 |                 if os.path.isdir(location): | 
 |                     return location.rstrip('/') + '/' | 
 |                 else: | 
 |                     return location | 
 |             tmpfile = os.path.join(tmpdir, os.path.basename(location)) | 
 |             if os.path.isdir(location): | 
 |                 tmpfile += '/' | 
 |                 shutil.copytree(location, tmpfile, symlinks=True) | 
 |                 return tmpfile | 
 |             shutil.copyfile(location, tmpfile) | 
 |             return tmpfile | 
 |         # location is just a string, dump it to a file | 
 |         else: | 
 |             tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir) | 
 |             tmpfileobj = os.fdopen(tmpfd, 'w') | 
 |             tmpfileobj.write(location) | 
 |             tmpfileobj.close() | 
 |             return tmpfile | 
 |  | 
 |  | 
 | def get_tmp_dir(): | 
 |     """Return the pathname of a directory on the host suitable | 
 |     for temporary file storage. | 
 |  | 
 |     The directory and its content will be deleted automatically | 
 |     at the end of the program execution if they are still present. | 
 |     """ | 
 |     dir_name = tempfile.mkdtemp(prefix="autoserv-") | 
 |     pid = os.getpid() | 
 |     if not pid in __tmp_dirs: | 
 |         __tmp_dirs[pid] = [] | 
 |     __tmp_dirs[pid].append(dir_name) | 
 |     return dir_name | 
 |  | 
 |  | 
 | def __clean_tmp_dirs(): | 
 |     """Erase temporary directories that were created by the get_tmp_dir() | 
 |     function and that are still present. | 
 |     """ | 
 |     pid = os.getpid() | 
 |     if pid not in __tmp_dirs: | 
 |         return | 
 |     for dir in __tmp_dirs[pid]: | 
 |         try: | 
 |             shutil.rmtree(dir) | 
 |         except OSError as e: | 
 |             if e.errno == 2: | 
 |                 pass | 
 |     __tmp_dirs[pid] = [] | 
 | atexit.register(__clean_tmp_dirs) | 
 | subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs()) | 
 |  | 
 |  | 
 | def unarchive(host, source_material): | 
 |     """Uncompress and untar an archive on a host. | 
 |  | 
 |     If the "source_material" is compresses (according to the file | 
 |     extension) it will be uncompressed. Supported compression formats | 
 |     are gzip and bzip2. Afterwards, if the source_material is a tar | 
 |     archive, it will be untarred. | 
 |  | 
 |     Args: | 
 |             host: the host object on which the archive is located | 
 |             source_material: the path of the archive on the host | 
 |  | 
 |     Returns: | 
 |             The file or directory name of the unarchived source material. | 
 |             If the material is a tar archive, it will be extracted in the | 
 |             directory where it is and the path returned will be the first | 
 |             entry in the archive, assuming it is the topmost directory. | 
 |             If the material is not an archive, nothing will be done so this | 
 |             function is "harmless" when it is "useless". | 
 |     """ | 
 |     # uncompress | 
 |     if (source_material.endswith(".gz") or | 
 |             source_material.endswith(".gzip")): | 
 |         host.run('gunzip "%s"' % (utils.sh_escape(source_material))) | 
 |         source_material= ".".join(source_material.split(".")[:-1]) | 
 |     elif source_material.endswith("bz2"): | 
 |         host.run('bunzip2 "%s"' % (utils.sh_escape(source_material))) | 
 |         source_material= ".".join(source_material.split(".")[:-1]) | 
 |  | 
 |     # untar | 
 |     if source_material.endswith(".tar"): | 
 |         retval= host.run('tar -C "%s" -xvf "%s"' % ( | 
 |                 utils.sh_escape(os.path.dirname(source_material)), | 
 |                 utils.sh_escape(source_material),)) | 
 |         source_material= os.path.join(os.path.dirname(source_material), | 
 |                 retval.stdout.split()[0]) | 
 |  | 
 |     return source_material | 
 |  | 
 |  | 
 | def get_server_dir(): | 
 |     path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__) | 
 |     return os.path.abspath(path) | 
 |  | 
 |  | 
 | def find_pid(command): | 
 |     for line in utils.system_output('ps -eo pid,cmd').rstrip().split('\n'): | 
 |         (pid, cmd) = line.split(None, 1) | 
 |         if re.search(command, cmd): | 
 |             return int(pid) | 
 |     return None | 
 |  | 
 |  | 
 | def default_mappings(machines): | 
 |     """ | 
 |     Returns a simple mapping in which all machines are assigned to the | 
 |     same key.  Provides the default behavior for | 
 |     form_ntuples_from_machines. """ | 
 |     mappings = {} | 
 |     failures = [] | 
 |  | 
 |     mach = machines[0] | 
 |     mappings['ident'] = [mach] | 
 |     if len(machines) > 1: | 
 |         machines = machines[1:] | 
 |         for machine in machines: | 
 |             mappings['ident'].append(machine) | 
 |  | 
 |     return (mappings, failures) | 
 |  | 
 |  | 
 | def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings): | 
 |     """Returns a set of ntuples from machines where the machines in an | 
 |        ntuple are in the same mapping, and a set of failures which are | 
 |        (machine name, reason) tuples.""" | 
 |     ntuples = [] | 
 |     (mappings, failures) = mapping_func(machines) | 
 |  | 
 |     # now run through the mappings and create n-tuples. | 
 |     # throw out the odd guys out | 
 |     for key in mappings: | 
 |         key_machines = mappings[key] | 
 |         total_machines = len(key_machines) | 
 |  | 
 |         # form n-tuples | 
 |         while len(key_machines) >= n: | 
 |             ntuples.append(key_machines[0:n]) | 
 |             key_machines = key_machines[n:] | 
 |  | 
 |         for mach in key_machines: | 
 |             failures.append((mach, "machine can not be tupled")) | 
 |  | 
 |     return (ntuples, failures) | 
 |  | 
 |  | 
 | def parse_machine(machine, user='root', password='', port=22): | 
 |     """ | 
 |     Parse the machine string user:pass@host:port and return it separately, | 
 |     if the machine string is not complete, use the default parameters | 
 |     when appropriate. | 
 |     """ | 
 |  | 
 |     if '@' in machine: | 
 |         user, machine = machine.split('@', 1) | 
 |  | 
 |     if ':' in user: | 
 |         user, password = user.split(':', 1) | 
 |  | 
 |     # Brackets are required to protect an IPv6 address whenever a | 
 |     # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to | 
 |     # it. Do not attempt to extract a (non-existent) port number from | 
 |     # an unprotected/bare IPv6 address "xx::xx". | 
 |     # In the Python >= 3.3 future, 'import ipaddress' will parse | 
 |     # addresses; and maybe more. | 
 |     bare_ipv6 = '[' != machine[0] and re.search(r':.*:', machine) | 
 |  | 
 |     # Extract trailing :port number if any. | 
 |     if not bare_ipv6 and re.search(r':\d*$', machine): | 
 |         machine, port = machine.rsplit(':', 1) | 
 |         port = int(port) | 
 |  | 
 |     # Strip any IPv6 brackets (ssh does not support them). | 
 |     # We'll add them back later for rsync, scp, etc. | 
 |     if machine[0] == '[' and machine[-1] == ']': | 
 |         machine = machine[1:-1] | 
 |  | 
 |     if not machine or not user: | 
 |         raise ValueError | 
 |  | 
 |     return machine, user, password, port | 
 |  | 
 |  | 
 | def get_public_key(): | 
 |     """ | 
 |     Return a valid string ssh public key for the user executing autoserv or | 
 |     autotest. If there's no DSA or RSA public key, create a DSA keypair with | 
 |     ssh-keygen and return it. | 
 |     """ | 
 |  | 
 |     ssh_conf_path = os.path.expanduser('~/.ssh') | 
 |  | 
 |     dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub') | 
 |     dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa') | 
 |  | 
 |     rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub') | 
 |     rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa') | 
 |  | 
 |     has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \ | 
 |         os.path.isfile(dsa_private_key_path) | 
 |     has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \ | 
 |         os.path.isfile(rsa_private_key_path) | 
 |  | 
 |     if has_dsa_keypair: | 
 |         print('DSA keypair found, using it') | 
 |         public_key_path = dsa_public_key_path | 
 |  | 
 |     elif has_rsa_keypair: | 
 |         print('RSA keypair found, using it') | 
 |         public_key_path = rsa_public_key_path | 
 |  | 
 |     else: | 
 |         print('Neither RSA nor DSA keypair found, creating DSA ssh key pair') | 
 |         utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path) | 
 |         public_key_path = dsa_public_key_path | 
 |  | 
 |     public_key = open(public_key_path, 'r') | 
 |     public_key_str = public_key.read() | 
 |     public_key.close() | 
 |  | 
 |     return public_key_str |