blob: 20fe9c20c57bafffe9fd8c2f4464383d7238f655 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2017-2023 The Khronos Group Inc.
# SPDX-License-Identifier: Apache-2.0
"""Generate a mapping of extension name -> all required extension names for
that extension, from dependencies in the API XML."""
import argparse
import errno
import xml.etree.ElementTree as etree
from pathlib import Path
from apiconventions import APIConventions
from parse_dependency import dependencyNames
class DiGraph:
    """A directed graph.

    The implementation and API mimic that of networkx.DiGraph in networkx-1.11.
    networkx implements graphs as nested dicts; it uses dicts all the way
    down, no lists.

    Some major differences between this implementation and that of
    networkx-1.11 are:

        * This omits edge and node attribute data, because we never use them
          yet they add additional code complexity.

        * This returns iterator objects when possible instead of collection
          objects, because it simplifies the implementation and should provide
          better performance.
    """

    def __init__(self):
        # Maps each node to its DiGraphNode, which holds the set of nodes
        # reachable from it by a single outgoing edge.
        self.__nodes = {}

    def add_node(self, node):
        """Add node to the graph. Adding a node that already exists is a no-op."""
        if node not in self.__nodes:
            self.__nodes[node] = DiGraphNode()

    def add_edge(self, src, dest):
        """Add the directed edge src -> dest.

        Either endpoint is added as a node first if it is not already in
        the graph, so every edge target is always a node of the graph.
        """
        self.add_node(src)
        self.add_node(dest)
        self.__nodes[src].adj.add(dest)

    def nodes(self):
        """Iterate over the nodes in the graph."""
        return self.__nodes.keys()

    def descendants(self, node):
        """
        Iterate over the nodes reachable from the given start node, excluding
        the start node itself. Each node in the graph is yielded at most once.

        The order in which nodes are yielded is unspecified.

        Raises KeyError if node is not in the graph. Because this is a
        generator, the error surfaces when iteration starts, not at the call.
        """
        # Implementation detail: depth-first traversal driven by an explicit
        # LIFO stack (list.append / list.pop), not recursion, so long
        # dependency chains cannot hit Python's recursion limit.
        # Bind the dict once; it is never rebound, only mutated.
        graph_nodes = self.__nodes

        # Nodes are marked as seen when pushed, not when popped. This means
        # each node is pushed, and therefore yielded, at most once. Seeding
        # with the start node means a cycle back to it does not yield it.
        seen = {node}
        stack = []

        def push_unseen_successors(src):
            # Queue every not-yet-seen successor of src for a visit.
            for succ in graph_nodes[src].adj:
                if succ not in seen:
                    seen.add(succ)
                    stack.append(succ)

        push_unseen_successors(node)
        while stack:
            current = stack.pop()
            yield current
            push_unseen_successors(current)


class DiGraphNode:
    """Per-node storage used by DiGraph."""

    def __init__(self):
        # Set of adjacent nodes: the targets of this node's outgoing edges.
        self.adj = set()
class ApiDependencies:
    """Extension dependency information loaded from an API XML registry.

    Records which extensions are supported for one API, the KHR and
    ratified subsets of those, and a DiGraph whose edges run from each
    supported extension to the extension names in its 'depends' expression.
    """

    def __init__(self,
                 registry_path = None,
                 api_name = None):
        """Load an API registry and generate extension dependencies

        registry_path - relative filename of XML registry. If not specified,
        uses the API default.

        api_name - API name for which to generate dependencies. Only
        extensions supported for that API are considered.
        """
        conventions = APIConventions()
        if registry_path is None:
            registry_path = conventions.registry_path
        if api_name is None:
            api_name = conventions.xml_api_name

        # Names of all extensions whose 'supported' list includes api_name.
        self.allExts = set()
        # Subset of allExts whose names contain the substring 'KHR'.
        self.khrExts = set()
        # Subset of allExts whose 'ratified' list includes api_name.
        self.ratifiedExts = set()
        # Edges run from a supported extension to each non-version name in
        # its 'depends' expression. Because add_edge adds both endpoints,
        # the graph may also contain dependency names that are not
        # themselves in allExts.
        self.graph = DiGraph()
        # NOTE(review): never populated by this class; kept for
        # compatibility with any external users of the attribute.
        self.extensions = {}
        self.tree = etree.parse(registry_path)

        # Loop over all supported extensions, creating a digraph of the
        # extension dependencies in the 'depends' attribute, which is a
        # boolean expression of core version and extension names.
        # A static dependency tree can be constructed only by treating all
        # extension names in the expression as dependencies, even though
        # that may not be true if it is of form (ext OR ext).
        # For the purpose these dependencies are used for - generating
        # specifications with required dependencies included automatically -
        # this will suffice.
        # Separately tracks lists of all extensions and all KHR extensions,
        # which are common specification targets.
        for elem in self.tree.findall('extensions/extension'):
            name = elem.get('name')
            # A missing attribute is treated as an empty list, as is already
            # done for 'ratified'. An extension without 'supported' is then
            # skipped instead of crashing on None.split().
            supported = elem.get('supported', '')
            ratified = elem.get('ratified', '')

            if api_name not in supported.split(','):
                # Skip unsupported extensions
                continue

            self.allExts.add(name)

            if 'KHR' in name:
                self.khrExts.add(name)

            if api_name in ratified.split(','):
                self.ratifiedExts.add(name)

            self.graph.add_node(name)

            depends = elem.get('depends')
            if depends:
                # Walk a list of the leaf nodes (version and extension
                # names) in the boolean expression.
                for dep in dependencyNames(depends):
                    # Filter out version names, which are explicitly
                    # specified when building a specification.
                    if not conventions.is_api_version_name(dep):
                        self.graph.add_edge(name, dep)

    def allExtensions(self):
        """Returns a set of all extensions in the graph"""
        return self.allExts

    def khrExtensions(self):
        """Returns a set of all KHR extensions in the graph"""
        return self.khrExts

    def ratifiedExtensions(self):
        """Returns a set of all ratified extensions in the graph"""
        return self.ratifiedExts

    def children(self, extension):
        """Returns a set of the dependencies of an extension.

        This includes transitive dependencies: every name reachable from the
        extension in the dependency graph, excluding the extension itself.

        Throws an exception if the extension is not in the graph."""
        if extension not in self.allExts:
            raise Exception(f'Extension {extension} not found in XML!')
        return set(self.graph.descendants(extension))
# Test script
if __name__ == '__main__':
    import time

    # Build the default registry path once; it is used for both the
    # argument default and its help text.
    default_registry = APIConventions().registry_path

    parser = argparse.ArgumentParser()
    parser.add_argument('-registry', action='store',
                        default=default_registry,
                        help='Use specified registry file instead of ' + default_registry)
    parser.add_argument('-loops', action='store',
                        default=10, type=int,
                        help='Number of timing loops to run')
    parser.add_argument('-test', action='store',
                        default=None,
                        help='Specify extension to find dependencies of')

    args = parser.parse_args()

    deps = ApiDependencies(args.registry)
    print('KHR exts =', sorted(deps.khrExtensions()))
    print('Ratified exts =', sorted(deps.ratifiedExtensions()))

    # Report the dependencies of the extension named by -test, if given.
    # children() raises if that extension is not in the registry.
    if args.test is not None:
        print(f'Dependencies of {args.test} =', sorted(deps.children(args.test)))

    # Time repeated full loads of the registry. A non-positive loop count
    # skips timing, since the per-loop figure would divide by zero.
    if args.loops > 0:
        startTime = time.process_time()
        for _ in range(args.loops):
            deps = ApiDependencies(args.registry)
        endTime = time.process_time()

        deltaT = endTime - startTime
        print('Total time = {} time/loop = {}'.format(deltaT, deltaT / args.loops))