diff --git a/py/test/rsession/hostmanage.py b/py/test/rsession/hostmanage.py index 13c629f26..8dc2dfddd 100644 --- a/py/test/rsession/hostmanage.py +++ b/py/test/rsession/hostmanage.py @@ -97,13 +97,16 @@ class HostRSync(py.execnet.RSync): return True # added the target class HostManager(object): - def __init__(self, sshhosts, config): - self.sshhosts = sshhosts + def __init__(self, config, hosts=None): self.config = config + if hosts is None: + hosts = self.config.getvalue("dist_hosts") + hosts = [HostInfo(x) for x in hosts] + self.hosts = hosts def prepare_gateways(self): dist_remotepython = self.config.getvalue("dist_remotepython") - for host in self.sshhosts: + for host in self.hosts: host.initgateway(python=dist_remotepython) host.gw.host = host @@ -116,7 +119,7 @@ class HostManager(object): rsync = HostRSync() for root in roots: destrelpath = root.relto(self.config.topdir) - for host in self.sshhosts: + for host in self.hosts: reporter(repevent.HostRSyncing(host)) def donecallback(): reporter(repevent.HostReady(host)) @@ -131,7 +134,7 @@ class HostManager(object): def setup_nodes(self, reporter): nodes = [] - for host in self.sshhosts: + for host in self.hosts: if hasattr(host.gw, 'remote_exec'): # otherwise dummy for tests :/ ch = setup_slave(host, self.config) nodes.append(MasterNode(ch, reporter)) diff --git a/py/test/rsession/rsession.py b/py/test/rsession/rsession.py index a4564bf9f..c72819985 100644 --- a/py/test/rsession/rsession.py +++ b/py/test/rsession/rsession.py @@ -29,7 +29,7 @@ class AbstractSession(Session): option.startserver = True super(AbstractSession, self).fixoptions() - def init_reporter(self, reporter, sshhosts, reporter_class, arg=""): + def init_reporter(self, reporter, hosts, reporter_class, arg=""): """ This initialises so called `reporter` class, which will handle all event presenting to user. Does not get called if main received custom reporter @@ -58,9 +58,9 @@ class AbstractSession(Session): from py.__.test.rsession.rest import RestReporter reporter_class = RestReporter if arg: - reporter_instance = reporter_class(self.config, sshhosts) + reporter_instance = reporter_class(self.config, hosts) else: - reporter_instance = reporter_class(self.config, sshhosts) + reporter_instance = reporter_class(self.config, hosts) reporter = reporter_instance.report else: startserverflag = False @@ -125,16 +125,15 @@ class RSession(AbstractSession): """ main loop for running tests. """ args = self.config.args - sshhosts = self._getconfighosts() + hm = HostManager(self.config) reporter, startserverflag = self.init_reporter(reporter, - sshhosts, RemoteReporter) + hm.hosts, RemoteReporter) reporter, checkfun = self.wrap_reporter(reporter) - reporter(repevent.TestStarted(sshhosts)) + reporter(repevent.TestStarted(hm.hosts)) - hostmanager = HostManager(sshhosts, self.config) try: - nodes = hostmanager.init_hosts(reporter) + nodes = hm.init_hosts(reporter) reporter(repevent.RsyncFinished()) try: self.dispatch_tests(nodes, reporter, checkfun) @@ -162,10 +161,6 @@ class RSession(AbstractSession): self.kill_server(startserverflag) raise - def _getconfighosts(self): - return [HostInfo(spec) for spec in - self.config.getvalue("dist_hosts")] - def dispatch_tests(self, nodes, reporter, checkfun): colitems = self.config.getcolitems() keyword = self.config.option.keyword @@ -179,17 +174,17 @@ class LSession(AbstractSession): def main(self, reporter=None, runner=None): # check out if used options makes any sense args = self.config.args - - sshhosts = [HostInfo('localhost')] # this is just an info to reporter - + + hm = HostManager(self.config, hosts=[HostInfo('localhost')]) + hosts = hm.hosts if not self.config.option.nomagic: py.magic.invoke(assertion=1) reporter, startserverflag = self.init_reporter(reporter, - sshhosts, LocalReporter, args[0]) + hosts, LocalReporter, args[0]) reporter, checkfun = self.wrap_reporter(reporter) - reporter(repevent.TestStarted(sshhosts)) + reporter(repevent.TestStarted(hosts)) colitems = self.config.getcolitems() reporter(repevent.RsyncFinished()) diff --git a/py/test/rsession/testing/test_hostmanage.py b/py/test/rsession/testing/test_hostmanage.py index 48d8a868d..ed1442d4e 100644 --- a/py/test/rsession/testing/test_hostmanage.py +++ b/py/test/rsession/testing/test_hostmanage.py @@ -109,11 +109,17 @@ class TestSyncing(DirSetup): assert not res2 class TestHostManager(DirSetup): + def test_hostmanager_custom_hosts(self): + config = py.test.config._reparse([self.source]) + hm = HostManager(config, hosts=[1,2,3]) + assert hm.hosts == [1,2,3] + def test_hostmanager_init_rsync_topdir(self): dir2 = self.source.ensure("dir1", "dir2", dir=1) dir2.ensure("hello") config = py.test.config._reparse([self.source]) - hm = HostManager([HostInfo("localhost:" + str(self.dest))], config) + hm = HostManager(config, + hosts=[HostInfo("localhost:" + str(self.dest))]) events = [] hm.init_rsync(reporter=events.append) assert self.dest.join("dir1").check() @@ -128,7 +134,8 @@ class TestHostManager(DirSetup): dist_rsync_roots = ['dir1/dir2'] """)) config = py.test.config._reparse([self.source]) - hm = HostManager([HostInfo("localhost:" + str(self.dest))], config) + hm = HostManager(config, + hosts=[HostInfo("localhost:" + str(self.dest))]) events = [] hm.init_rsync(reporter=events.append) assert self.dest.join("dir1").check() diff --git a/py/test/rsession/testing/test_rsession.py b/py/test/rsession/testing/test_rsession.py index 25d64c4e1..992a6759a 100644 --- a/py/test/rsession/testing/test_rsession.py +++ b/py/test/rsession/testing/test_rsession.py @@ -121,7 +121,7 @@ class TestRSessionRemote(DirSetup): teardown_events = [] tmpdir = py.test.ensuretemp("emptyconftest") config = py.test.config._reparse([tmpdir]) - hm = HostManager(hosts, config) + hm = HostManager(config, hosts) nodes = hm.init_hosts(setup_events.append) hm.teardown_hosts(teardown_events.append, [node.channel for node in nodes], nodes) @@ -146,7 +146,7 @@ class TestRSessionRemote(DirSetup): allevents = [] config = py.test.config._reparse([]) - hm = HostManager(hosts, config) + hm = HostManager(config, hosts=hosts) nodes = hm.init_hosts(allevents.append) from py.__.test.rsession.testing.test_executor \