| from .core import unify, reify # type: ignore[attr-defined] |
| from .variable import isvar |
| from .utils import _toposort, freeze |
| from .unification_tools import groupby, first # type: ignore[import] |
| |
| |
class Dispatcher:
    """Dispatch calls by unifying positional arguments against registered
    signatures.

    Signatures are tried in the order computed by ``ordering``. The first one
    that has the same length as the arguments and unifies with them selects
    the implementation.
    """

    def __init__(self, name):
        self.name = name
        # Frozen signature -> implementation registered for it.
        self.funcs = {}
        # Signatures in the order ``resolve`` tries them; rebuilt by ``add``.
        self.ordering = []

    def add(self, signature, func):
        """Register ``func`` for ``signature`` and recompute the search order."""
        self.funcs[freeze(signature)] = func
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Resolve an implementation from ``args`` and call it with the
        original ``*args, **kwargs``. The unifier found during resolution is
        not used here."""
        func, _ = self.resolve(args)
        return func(*args, **kwargs)

    def resolve(self, args):
        """Return ``(func, substitution)`` for the first matching signature.

        Raises:
            NotImplementedError: if no registered signature matches ``args``.
        """
        n = len(args)
        # ``args`` is the same for every candidate, so freeze it once rather
        # than once per signature tried.
        frozen_args = freeze(args)
        for signature in self.ordering:
            if len(signature) != n:
                continue
            s = unify(frozen_args, signature)
            if s is not False:
                return self.funcs[signature], s
        raise NotImplementedError("No match found. \nKnown matches: "
                                  + str(self.ordering) + "\nInput: " + str(args))

    def register(self, *signature):
        """Decorator form of ``add``.

        The decorator returns this dispatcher, so the decorated name is
        rebound to the dispatcher rather than to the original function.
        """
        def _(func):
            self.add(signature, func)
            return self
        return _
| |
| |
class VarDispatcher(Dispatcher):
    """Dispatcher whose implementations receive the unifier's bindings as
    keyword arguments, named by each bound variable's ``token``.

    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher('d')
    >>> x = var('x')
    >>> @d.register('inc', x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register('double', x)
    ... def f(x):
    ...     return x * 2
    >>> d('inc', 10)
    11
    >>> d('double', 10)
    20
    """

    def __call__(self, *args, **kwargs):
        # Positional arguments are used only for matching. ``kwargs`` is
        # accepted for signature compatibility but is NOT forwarded.
        impl, bindings = self.resolve(args)
        # Presumably every key of the substitution is a logic variable that
        # exposes ``.token``; verify against ``unify``'s contract.
        named = dict((bound.token, value) for bound, value in bindings.items())
        return impl(**named)
| |
| |
# Default registry used by ``match`` when no ``namespace`` keyword is given.
# It maps a decorated function's ``__name__`` to the dispatcher holding its
# implementations.
global_namespace = dict()  # type: ignore[var-annotated]
| |
| |
def match(*signature, **kwargs):
    """Decorator factory: register the decorated function under ``signature``.

    The dispatcher is looked up by the function's ``__name__`` in the
    ``namespace`` keyword (default ``global_namespace``). If absent, it is
    created with the class given by the ``Dispatcher`` keyword (default
    ``Dispatcher``). The decorator returns that dispatcher, so the decorated
    name is rebound to it.
    """
    registry = kwargs.get('namespace', global_namespace)
    dispatcher_cls = kwargs.get('Dispatcher', Dispatcher)

    def register_impl(func):
        key = func.__name__
        if key not in registry:
            registry[key] = dispatcher_cls(key)
        target = registry[key]
        target.add(signature, func)
        return target

    return register_impl
| |
| |
def supercedes(a, b):
    """Return True if ``a`` is a more specific match than ``b``, else False.

    ``a`` supercedes ``b`` in either of these cases:

    - ``b`` is a bare variable and ``a`` is not;
    - ``a`` and ``b`` unify, and applying the unifier (without its
      variable-to-variable bindings) leaves ``a`` unchanged.

    In every other case the result is False. This includes the case where
    neither side is left unchanged, which used to return an implicit None.
    """
    if isvar(b) and not isvar(a):
        return True
    s = unify(a, b)
    if s is False:
        return False
    # Drop var -> var bindings: mapping one variable onto another does not
    # make either pattern more specific.
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    if reify(a, s) == a:
        return True
    # Whether ``b`` is unchanged (so ``b`` is the more specific one) or
    # neither side is, ``a`` does not supercede ``b``.
    return False
| |
| |
| # Taken from multipledispatch |
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """Return True if signature ``a`` should be checked before ``b``.

    This holds when ``a`` supercedes ``b`` and ``b`` does not supercede
    ``a``. When each supercedes the other, the result is
    ``tie_breaker(a) > tie_breaker(b)``, using ``hash`` by default.
    """
    if not supercedes(a, b):
        return False
    if not supercedes(b, a):
        return True
    return tie_breaker(a) > tie_breaker(b)
| |
| |
| # Taken from multipledispatch |
# Taken from multipledispatch
def ordering(signatures):
    """Return the signatures (as tuples) in the order they should be tried.

    Builds a graph with an edge ``a -> b`` whenever ``edge(a, b)`` holds,
    meaning ``a`` must be checked before ``b``, then topologically sorts it
    with ``_toposort``.
    """
    sigs = [tuple(sig) for sig in signatures]
    before = [(a, b) for a in sigs for b in sigs if edge(a, b)]
    # Bucket the (a, b) pairs by their source signature ``a``.
    by_source = groupby(first, before)
    # Signatures with no outgoing edges must still appear as graph nodes.
    for sig in sigs:
        by_source.setdefault(sig, [])
    graph = {src: [dst for _, dst in pairs] for src, pairs in by_source.items()}
    return _toposort(graph)