simplify _scan_plugin implementation and store argnames on HookCaller
This commit is contained in:
parent
351931d5ca
commit
e635f9f9b2
|
@ -68,7 +68,7 @@ class TagTracerSub:
|
||||||
return self.__class__(self.root, self.tags + (name,))
|
return self.__class__(self.root, self.tags + (name,))
|
||||||
|
|
||||||
class PluginManager(object):
|
class PluginManager(object):
|
||||||
def __init__(self, hookspecs=None):
|
def __init__(self, hookspecs=None, prefix="pytest_"):
|
||||||
self._name2plugin = {}
|
self._name2plugin = {}
|
||||||
self._listattrcache = {}
|
self._listattrcache = {}
|
||||||
self._plugins = []
|
self._plugins = []
|
||||||
|
@ -77,7 +77,7 @@ class PluginManager(object):
|
||||||
self.trace = TagTracer().get("pluginmanage")
|
self.trace = TagTracer().get("pluginmanage")
|
||||||
self._plugin_distinfo = []
|
self._plugin_distinfo = []
|
||||||
self._shutdown = []
|
self._shutdown = []
|
||||||
self.hook = HookRelay(hookspecs or [], pm=self)
|
self.hook = HookRelay(hookspecs or [], pm=self, prefix=prefix)
|
||||||
|
|
||||||
def do_configure(self, config):
|
def do_configure(self, config):
|
||||||
# backward compatibility
|
# backward compatibility
|
||||||
|
@ -384,22 +384,25 @@ class HookRelay:
|
||||||
self._hookspecs = []
|
self._hookspecs = []
|
||||||
self._pm = pm
|
self._pm = pm
|
||||||
self.trace = pm.trace.root.get("hook")
|
self.trace = pm.trace.root.get("hook")
|
||||||
|
self.prefix = prefix
|
||||||
for hookspec in hookspecs:
|
for hookspec in hookspecs:
|
||||||
self._addhooks(hookspec, prefix)
|
self._addhooks(hookspec, prefix)
|
||||||
|
|
||||||
def _addhooks(self, hookspecs, prefix):
|
def _addhooks(self, hookspec, prefix):
|
||||||
self._hookspecs.append(hookspecs)
|
self._hookspecs.append(hookspec)
|
||||||
added = False
|
added = False
|
||||||
for name, method in vars(hookspecs).items():
|
for name in dir(hookspec):
|
||||||
if name.startswith(prefix):
|
if name.startswith(prefix):
|
||||||
|
method = getattr(hookspec, name)
|
||||||
firstresult = getattr(method, 'firstresult', False)
|
firstresult = getattr(method, 'firstresult', False)
|
||||||
hc = HookCaller(self, name, firstresult=firstresult)
|
hc = HookCaller(self, name, firstresult=firstresult,
|
||||||
|
argnames=varnames(method))
|
||||||
setattr(self, name, hc)
|
setattr(self, name, hc)
|
||||||
added = True
|
added = True
|
||||||
#print ("setting new hook", name)
|
#print ("setting new hook", name)
|
||||||
if not added:
|
if not added:
|
||||||
raise ValueError("did not find new %r hooks in %r" %(
|
raise ValueError("did not find new %r hooks in %r" %(
|
||||||
prefix, hookspecs,))
|
prefix, hookspec,))
|
||||||
|
|
||||||
def _getcaller(self, name, plugins):
|
def _getcaller(self, name, plugins):
|
||||||
caller = getattr(self, name)
|
caller = getattr(self, name)
|
||||||
|
@ -409,62 +412,44 @@ class HookRelay:
|
||||||
return caller
|
return caller
|
||||||
|
|
||||||
def _scan_plugin(self, plugin):
|
def _scan_plugin(self, plugin):
|
||||||
methods = collectattr(plugin)
|
def fail(msg, *args):
|
||||||
hooks = {}
|
name = getattr(plugin, '__name__', plugin)
|
||||||
for hookspec in self._hookspecs:
|
raise PluginValidationError("plugin %r\n%s" %(name, msg % args))
|
||||||
hooks.update(collectattr(hookspec))
|
|
||||||
|
|
||||||
stringio = py.io.TextIO()
|
for name in dir(plugin):
|
||||||
def Print(*args):
|
if not name.startswith(self.prefix):
|
||||||
if args:
|
|
||||||
stringio.write(" ".join(map(str, args)))
|
|
||||||
stringio.write("\n")
|
|
||||||
|
|
||||||
fail = False
|
|
||||||
while methods:
|
|
||||||
name, method = methods.popitem()
|
|
||||||
#print "checking", name
|
|
||||||
if isgenerichook(name):
|
|
||||||
continue
|
continue
|
||||||
if name not in hooks:
|
hook = getattr(self, name, None)
|
||||||
if not getattr(method, 'optionalhook', False):
|
method = getattr(plugin, name)
|
||||||
Print("found unknown hook:", name)
|
if hook is None:
|
||||||
fail = True
|
is_optional = getattr(method, 'optionalhook', False)
|
||||||
else:
|
if not isgenerichook(name) and not is_optional:
|
||||||
#print "checking", method
|
fail("found unknown hook: %r", name)
|
||||||
method_args = list(varnames(method))
|
continue
|
||||||
if '__multicall__' in method_args:
|
for arg in varnames(method):
|
||||||
method_args.remove('__multicall__')
|
if arg not in hook.argnames:
|
||||||
hook = hooks[name]
|
fail("argument %r not available\n"
|
||||||
hookargs = varnames(hook)
|
"actual definition: %s\n"
|
||||||
for arg in method_args:
|
"available hookargs: %s",
|
||||||
if arg not in hookargs:
|
arg, formatdef(method),
|
||||||
Print("argument %r not available" %(arg, ))
|
", ".join(hook.argnames))
|
||||||
Print("actual definition: %s" %(formatdef(method)))
|
getattr(self, name).clear_method_cache()
|
||||||
Print("available hook arguments: %s" %
|
|
||||||
", ".join(hookargs))
|
|
||||||
fail = True
|
|
||||||
break
|
|
||||||
#if not fail:
|
|
||||||
# print "matching hook:", formatdef(method)
|
|
||||||
getattr(self, name).clear_method_cache()
|
|
||||||
|
|
||||||
if fail:
|
|
||||||
name = getattr(plugin, '__name__', plugin)
|
|
||||||
raise PluginValidationError("%s:\n%s" % (name, stringio.getvalue()))
|
|
||||||
|
|
||||||
|
|
||||||
class HookCaller:
|
class HookCaller:
|
||||||
def __init__(self, hookrelay, name, firstresult, methods=None):
|
def __init__(self, hookrelay, name, firstresult, argnames, methods=None):
|
||||||
self.hookrelay = hookrelay
|
self.hookrelay = hookrelay
|
||||||
self.name = name
|
self.name = name
|
||||||
self.firstresult = firstresult
|
self.firstresult = firstresult
|
||||||
self.trace = self.hookrelay.trace
|
self.trace = self.hookrelay.trace
|
||||||
self.methods = methods
|
self.methods = methods
|
||||||
|
self.argnames = ["__multicall__"]
|
||||||
|
self.argnames.extend(argnames)
|
||||||
|
assert "self" not in argnames
|
||||||
|
|
||||||
def new_cached_caller(self, methods):
|
def new_cached_caller(self, methods):
|
||||||
return HookCaller(self.hookrelay, self.name, self.firstresult,
|
return HookCaller(self.hookrelay, self.name, self.firstresult,
|
||||||
methods=methods)
|
argnames=self.argnames, methods=methods)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<HookCaller %r>" %(self.name,)
|
return "<HookCaller %r>" %(self.name,)
|
||||||
|
@ -474,14 +459,13 @@ class HookCaller:
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
methods = self.methods
|
methods = self.methods
|
||||||
if self.methods is None:
|
if methods is None:
|
||||||
self.methods = methods = self.hookrelay._pm.listattr(self.name)
|
self.methods = methods = self.hookrelay._pm.listattr(self.name)
|
||||||
methods = self.methods
|
|
||||||
return self._docall(methods, kwargs)
|
return self._docall(methods, kwargs)
|
||||||
|
|
||||||
def callextra(self, methods, **kwargs):
|
def callextra(self, methods, **kwargs):
|
||||||
if self.methods is None:
|
#if self.methods is None:
|
||||||
self.reload_methods()
|
# self.reload_methods()
|
||||||
return self._docall(self.methods + methods, kwargs)
|
return self._docall(self.methods + methods, kwargs)
|
||||||
|
|
||||||
def _docall(self, methods, kwargs):
|
def _docall(self, methods, kwargs):
|
||||||
|
|
|
@ -630,6 +630,18 @@ class TestHookRelay:
|
||||||
assert l == [4]
|
assert l == [4]
|
||||||
assert not hasattr(mcm, 'world')
|
assert not hasattr(mcm, 'world')
|
||||||
|
|
||||||
|
def test_argmismatch(self):
|
||||||
|
class Api:
|
||||||
|
def hello(self, arg):
|
||||||
|
"api hook 1"
|
||||||
|
pm = PluginManager(Api, prefix="he")
|
||||||
|
class Plugin:
|
||||||
|
def hello(self, argwrong):
|
||||||
|
return arg + 1
|
||||||
|
with pytest.raises(PluginValidationError) as exc:
|
||||||
|
pm.register(Plugin())
|
||||||
|
assert "argwrong" in str(exc.value)
|
||||||
|
|
||||||
def test_only_kwargs(self):
|
def test_only_kwargs(self):
|
||||||
pm = PluginManager()
|
pm = PluginManager()
|
||||||
class Api:
|
class Api:
|
||||||
|
|
Loading…
Reference in New Issue