simplify _scan_plugin implementation and store argnames on HookCaller

This commit is contained in:
holger krekel 2014-10-01 13:57:35 +02:00
parent 351931d5ca
commit e635f9f9b2
2 changed files with 50 additions and 54 deletions

View File

@ -68,7 +68,7 @@ class TagTracerSub:
return self.__class__(self.root, self.tags + (name,)) return self.__class__(self.root, self.tags + (name,))
class PluginManager(object): class PluginManager(object):
def __init__(self, hookspecs=None): def __init__(self, hookspecs=None, prefix="pytest_"):
self._name2plugin = {} self._name2plugin = {}
self._listattrcache = {} self._listattrcache = {}
self._plugins = [] self._plugins = []
@ -77,7 +77,7 @@ class PluginManager(object):
self.trace = TagTracer().get("pluginmanage") self.trace = TagTracer().get("pluginmanage")
self._plugin_distinfo = [] self._plugin_distinfo = []
self._shutdown = [] self._shutdown = []
self.hook = HookRelay(hookspecs or [], pm=self) self.hook = HookRelay(hookspecs or [], pm=self, prefix=prefix)
def do_configure(self, config): def do_configure(self, config):
# backward compatibility # backward compatibility
@ -384,22 +384,25 @@ class HookRelay:
self._hookspecs = [] self._hookspecs = []
self._pm = pm self._pm = pm
self.trace = pm.trace.root.get("hook") self.trace = pm.trace.root.get("hook")
self.prefix = prefix
for hookspec in hookspecs: for hookspec in hookspecs:
self._addhooks(hookspec, prefix) self._addhooks(hookspec, prefix)
def _addhooks(self, hookspecs, prefix): def _addhooks(self, hookspec, prefix):
self._hookspecs.append(hookspecs) self._hookspecs.append(hookspec)
added = False added = False
for name, method in vars(hookspecs).items(): for name in dir(hookspec):
if name.startswith(prefix): if name.startswith(prefix):
method = getattr(hookspec, name)
firstresult = getattr(method, 'firstresult', False) firstresult = getattr(method, 'firstresult', False)
hc = HookCaller(self, name, firstresult=firstresult) hc = HookCaller(self, name, firstresult=firstresult,
argnames=varnames(method))
setattr(self, name, hc) setattr(self, name, hc)
added = True added = True
#print ("setting new hook", name) #print ("setting new hook", name)
if not added: if not added:
raise ValueError("did not find new %r hooks in %r" %( raise ValueError("did not find new %r hooks in %r" %(
prefix, hookspecs,)) prefix, hookspec,))
def _getcaller(self, name, plugins): def _getcaller(self, name, plugins):
caller = getattr(self, name) caller = getattr(self, name)
@ -409,62 +412,44 @@ class HookRelay:
return caller return caller
def _scan_plugin(self, plugin): def _scan_plugin(self, plugin):
methods = collectattr(plugin) def fail(msg, *args):
hooks = {} name = getattr(plugin, '__name__', plugin)
for hookspec in self._hookspecs: raise PluginValidationError("plugin %r\n%s" %(name, msg % args))
hooks.update(collectattr(hookspec))
stringio = py.io.TextIO() for name in dir(plugin):
def Print(*args): if not name.startswith(self.prefix):
if args:
stringio.write(" ".join(map(str, args)))
stringio.write("\n")
fail = False
while methods:
name, method = methods.popitem()
#print "checking", name
if isgenerichook(name):
continue continue
if name not in hooks: hook = getattr(self, name, None)
if not getattr(method, 'optionalhook', False): method = getattr(plugin, name)
Print("found unknown hook:", name) if hook is None:
fail = True is_optional = getattr(method, 'optionalhook', False)
else: if not isgenerichook(name) and not is_optional:
#print "checking", method fail("found unknown hook: %r", name)
method_args = list(varnames(method)) continue
if '__multicall__' in method_args: for arg in varnames(method):
method_args.remove('__multicall__') if arg not in hook.argnames:
hook = hooks[name] fail("argument %r not available\n"
hookargs = varnames(hook) "actual definition: %s\n"
for arg in method_args: "available hookargs: %s",
if arg not in hookargs: arg, formatdef(method),
Print("argument %r not available" %(arg, )) ", ".join(hook.argnames))
Print("actual definition: %s" %(formatdef(method))) getattr(self, name).clear_method_cache()
Print("available hook arguments: %s" %
", ".join(hookargs))
fail = True
break
#if not fail:
# print "matching hook:", formatdef(method)
getattr(self, name).clear_method_cache()
if fail:
name = getattr(plugin, '__name__', plugin)
raise PluginValidationError("%s:\n%s" % (name, stringio.getvalue()))
class HookCaller: class HookCaller:
def __init__(self, hookrelay, name, firstresult, methods=None): def __init__(self, hookrelay, name, firstresult, argnames, methods=None):
self.hookrelay = hookrelay self.hookrelay = hookrelay
self.name = name self.name = name
self.firstresult = firstresult self.firstresult = firstresult
self.trace = self.hookrelay.trace self.trace = self.hookrelay.trace
self.methods = methods self.methods = methods
self.argnames = ["__multicall__"]
self.argnames.extend(argnames)
assert "self" not in argnames
def new_cached_caller(self, methods): def new_cached_caller(self, methods):
return HookCaller(self.hookrelay, self.name, self.firstresult, return HookCaller(self.hookrelay, self.name, self.firstresult,
methods=methods) argnames=self.argnames, methods=methods)
def __repr__(self): def __repr__(self):
return "<HookCaller %r>" %(self.name,) return "<HookCaller %r>" %(self.name,)
@ -474,14 +459,13 @@ class HookCaller:
def __call__(self, **kwargs): def __call__(self, **kwargs):
methods = self.methods methods = self.methods
if self.methods is None: if methods is None:
self.methods = methods = self.hookrelay._pm.listattr(self.name) self.methods = methods = self.hookrelay._pm.listattr(self.name)
methods = self.methods
return self._docall(methods, kwargs) return self._docall(methods, kwargs)
def callextra(self, methods, **kwargs): def callextra(self, methods, **kwargs):
if self.methods is None: #if self.methods is None:
self.reload_methods() # self.reload_methods()
return self._docall(self.methods + methods, kwargs) return self._docall(self.methods + methods, kwargs)
def _docall(self, methods, kwargs): def _docall(self, methods, kwargs):

View File

@ -630,6 +630,18 @@ class TestHookRelay:
assert l == [4] assert l == [4]
assert not hasattr(mcm, 'world') assert not hasattr(mcm, 'world')
def test_argmismatch(self):
class Api:
def hello(self, arg):
"api hook 1"
pm = PluginManager(Api, prefix="he")
class Plugin:
def hello(self, argwrong):
return arg + 1
with pytest.raises(PluginValidationError) as exc:
pm.register(Plugin())
assert "argwrong" in str(exc.value)
def test_only_kwargs(self): def test_only_kwargs(self):
pm = PluginManager() pm = PluginManager()
class Api: class Api: