302 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			302 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| import threading
 | |
| import Queue
 | |
| import traceback
 | |
| import atexit
 | |
| import weakref
 | |
| import __future__
 | |
| 
 | |
| # note that the whole code of this module (as well as some
 | |
| # other modules) execute not only on the local side but 
 | |
| # also on any gateway's remote side.  On such remote sides
 | |
| # we cannot assume the py library to be there and 
 | |
| # InstallableGateway._remote_bootstrap_gateway() (located 
 | |
| # in register.py) will take care to send source fragments
 | |
| # to the other side.  Yes, it is fragile but we have a 
 | |
| # few tests that try to catch when we mess up. 
 | |
| 
 | |
| # XXX the following lines should not be here
 | |
| if 'ThreadOut' not in globals(): 
 | |
|     import py 
 | |
|     from py.code import Source
 | |
|     from py.__.execnet.channel import ChannelFactory, Channel
 | |
|     from py.__.execnet.message import Message
 | |
|     ThreadOut = py._thread.ThreadOut 
 | |
|     WorkerPool = py._thread.WorkerPool 
 | |
|     NamedThreadPool = py._thread.NamedThreadPool 
 | |
| 
 | |
| import os
 | |
| debug = 0 # open('/tmp/execnet-debug-%d' % os.getpid()  , 'wa')
 | |
| 
 | |
| sysex = (KeyboardInterrupt, SystemExit)
 | |
| 
 | |
| class Gateway(object):
 | |
|     _ThreadOut = ThreadOut 
 | |
|     remoteaddress = ""
 | |
|     def __init__(self, io, execthreads=None, _startcount=2): 
 | |
|         """ initialize core gateway, using the given 
 | |
|             inputoutput object and 'execthreads' execution
 | |
|             threads. 
 | |
|         """
 | |
|         global registered_cleanup
 | |
|         self._execpool = WorkerPool(maxthreads=execthreads)
 | |
|         self._io = io
 | |
|         self._outgoing = Queue.Queue()
 | |
|         self._channelfactory = ChannelFactory(self, _startcount)
 | |
|         if not registered_cleanup:
 | |
|             atexit.register(cleanup_atexit)
 | |
|             registered_cleanup = True
 | |
|         _active_sendqueues[self._outgoing] = True
 | |
|         self._pool = NamedThreadPool(receiver = self._thread_receiver, 
 | |
|                                      sender = self._thread_sender)
 | |
| 
 | |
|     def __repr__(self):
 | |
|         """ return string representing gateway type and status. """
 | |
|         addr = self.remoteaddress 
 | |
|         if addr:
 | |
|             addr = '[%s]' % (addr,)
 | |
|         else:
 | |
|             addr = ''
 | |
|         try:
 | |
|             r = (len(self._pool.getstarted('receiver'))
 | |
|                  and "receiving" or "not receiving")
 | |
|             s = (len(self._pool.getstarted('sender')) 
 | |
|                  and "sending" or "not sending")
 | |
|             i = len(self._channelfactory.channels())
 | |
|         except AttributeError:
 | |
|             r = s = "uninitialized"
 | |
|             i = "no"
 | |
|         return "<%s%s %s/%s (%s active channels)>" %(
 | |
|                 self.__class__.__name__, addr, r, s, i)
 | |
| 
 | |
| ##    def _local_trystopexec(self):
 | |
| ##        self._execpool.shutdown() 
 | |
| 
 | |
|     def _trace(self, *args):
 | |
|         if debug:
 | |
|             try:
 | |
|                 l = "\n".join(args).split(os.linesep)
 | |
|                 id = getid(self)
 | |
|                 for x in l:
 | |
|                     print >>debug, x
 | |
|                 debug.flush()
 | |
|             except sysex:
 | |
|                 raise
 | |
|             except:
 | |
|                 traceback.print_exc()
 | |
| 
 | |
|     def _traceex(self, excinfo):
 | |
|         try:
 | |
|             l = traceback.format_exception(*excinfo)
 | |
|             errortext = "".join(l)
 | |
|         except:
 | |
|             errortext = '%s: %s' % (excinfo[0].__name__,
 | |
|                                     excinfo[1])
 | |
|         self._trace(errortext)
 | |
| 
 | |
|     def _thread_receiver(self):
 | |
|         """ thread to read and handle Messages half-sync-half-async. """
 | |
|         try:
 | |
|             from sys import exc_info
 | |
|             while 1:
 | |
|                 try:
 | |
|                     msg = Message.readfrom(self._io)
 | |
|                     self._trace("received <- %r" % msg)
 | |
|                     msg.received(self)
 | |
|                 except sysex:
 | |
|                     break
 | |
|                 except EOFError:
 | |
|                     break
 | |
|                 except:
 | |
|                     self._traceex(exc_info())
 | |
|                     break 
 | |
|         finally:
 | |
|             self._send(None)
 | |
|             self._channelfactory._finished_receiving()
 | |
|             self._trace('leaving %r' % threading.currentThread())
 | |
| 
 | |
|     def _send(self, msg):
 | |
|         self._outgoing.put(msg)
 | |
| 
 | |
|     def _thread_sender(self):
 | |
|         """ thread to send Messages over the wire. """
 | |
|         try:
 | |
|             from sys import exc_info
 | |
|             while 1:
 | |
|                 msg = self._outgoing.get()
 | |
|                 try:
 | |
|                     if msg is None:
 | |
|                         self._io.close_write()
 | |
|                         break
 | |
|                     msg.writeto(self._io)
 | |
|                 except:
 | |
|                     excinfo = exc_info()
 | |
|                     self._traceex(excinfo)
 | |
|                     if msg is not None:
 | |
|                         msg.post_sent(self, excinfo)
 | |
|                     break
 | |
|                 else:
 | |
|                     self._trace('sent -> %r' % msg)
 | |
|                     msg.post_sent(self)
 | |
|         finally:
 | |
|             self._trace('leaving %r' % threading.currentThread())
 | |
| 
 | |
|     def _local_redirect_thread_output(self, outid, errid): 
 | |
|         l = []
 | |
|         for name, id in ('stdout', outid), ('stderr', errid): 
 | |
|             if id: 
 | |
|                 channel = self._channelfactory.new(outid)
 | |
|                 out = self._ThreadOut(sys, name)
 | |
|                 out.setwritefunc(channel.send) 
 | |
|                 l.append((out, channel))
 | |
|         def close(): 
 | |
|             for out, channel in l: 
 | |
|                 out.delwritefunc() 
 | |
|                 channel.close() 
 | |
|         return close 
 | |
| 
 | |
|     def _thread_executor(self, channel, (source, outid, errid)):
 | |
|         """ worker thread to execute source objects from the execution queue. """
 | |
|         from sys import exc_info
 | |
|         try:
 | |
|             loc = { 'channel' : channel }
 | |
|             self._trace("execution starts:", repr(source)[:50])
 | |
|             close = self._local_redirect_thread_output(outid, errid) 
 | |
|             try:
 | |
|                 co = compile(source+'\n', '', 'exec',
 | |
|                              __future__.CO_GENERATOR_ALLOWED)
 | |
|                 exec co in loc
 | |
|             finally:
 | |
|                 close() 
 | |
|                 self._trace("execution finished:", repr(source)[:50])
 | |
|         except (KeyboardInterrupt, SystemExit):
 | |
|             pass 
 | |
|         except:
 | |
|             excinfo = exc_info()
 | |
|             l = traceback.format_exception(*excinfo)
 | |
|             errortext = "".join(l)
 | |
|             channel.close(errortext)
 | |
|             self._trace(errortext)
 | |
|         else:
 | |
|             channel.close()
 | |
| 
 | |
|     def _local_schedulexec(self, channel, sourcetask): 
 | |
|         self._trace("dispatching exec")
 | |
|         self._execpool.dispatch(self._thread_executor, channel, sourcetask) 
 | |
| 
 | |
|     def _newredirectchannelid(self, callback): 
 | |
|         if callback is None: 
 | |
|             return  
 | |
|         if hasattr(callback, 'write'): 
 | |
|             callback = callback.write 
 | |
|         assert callable(callback) 
 | |
|         chan = self.newchannel()
 | |
|         chan.setcallback(callback)
 | |
|         return chan.id 
 | |
| 
 | |
|     # _____________________________________________________________________
 | |
|     #
 | |
|     # High Level Interface
 | |
|     # _____________________________________________________________________
 | |
|     #
 | |
|     def newchannel(self): 
 | |
|         """ return new channel object.  """ 
 | |
|         return self._channelfactory.new()
 | |
| 
 | |
|     def remote_exec(self, source, stdout=None, stderr=None): 
 | |
|         """ return channel object and connect it to a remote
 | |
|             execution thread where the given 'source' executes
 | |
|             and has the sister 'channel' object in its global 
 | |
|             namespace.  The callback functions 'stdout' and 
 | |
|             'stderr' get called on receival of remote 
 | |
|             stdout/stderr output strings. 
 | |
|         """
 | |
|         try:
 | |
|             source = str(Source(source))
 | |
|         except NameError: 
 | |
|             try: 
 | |
|                 import py 
 | |
|                 source = str(py.code.Source(source))
 | |
|             except ImportError: 
 | |
|                 pass 
 | |
|         channel = self.newchannel() 
 | |
|         outid = self._newredirectchannelid(stdout) 
 | |
|         errid = self._newredirectchannelid(stderr) 
 | |
|         self._send(Message.CHANNEL_OPEN(
 | |
|                     channel.id, (source, outid, errid)))
 | |
|         return channel 
 | |
| 
 | |
|     def _remote_redirect(self, stdout=None, stderr=None): 
 | |
|         """ return a handle representing a redirection of a remote 
 | |
|             end's stdout to a local file object.  with handle.close() 
 | |
|             the redirection will be reverted.   
 | |
|         """ 
 | |
|         clist = []
 | |
|         for name, out in ('stdout', stdout), ('stderr', stderr): 
 | |
|             if out: 
 | |
|                 outchannel = self.newchannel()
 | |
|                 outchannel.setcallback(getattr(out, 'write', out))
 | |
|                 channel = self.remote_exec(""" 
 | |
|                     import sys
 | |
|                     outchannel = channel.receive() 
 | |
|                     outchannel.gateway._ThreadOut(sys, %r).setdefaultwriter(outchannel.send)
 | |
|                 """ % name) 
 | |
|                 channel.send(outchannel)
 | |
|                 clist.append(channel)
 | |
|         for c in clist: 
 | |
|             c.waitclose() 
 | |
|         class Handle: 
 | |
|             def close(_): 
 | |
|                 for name, out in ('stdout', stdout), ('stderr', stderr): 
 | |
|                     if out: 
 | |
|                         c = self.remote_exec("""
 | |
|                             import sys
 | |
|                             channel.gateway._ThreadOut(sys, %r).resetdefault()
 | |
|                         """ % name) 
 | |
|                         c.waitclose() 
 | |
|         return Handle()
 | |
| 
 | |
|     def exit(self):
 | |
|         """ Try to stop all IO activity. """
 | |
|         try:
 | |
|             del _active_sendqueues[self._outgoing]
 | |
|         except KeyError:
 | |
|             pass
 | |
|         else:
 | |
|             self._send(None)
 | |
| 
 | |
|     def join(self, joinexec=True):
 | |
|         """ Wait for all IO (and by default all execution activity) 
 | |
|             to stop. 
 | |
|         """
 | |
|         current = threading.currentThread()
 | |
|         for x in self._pool.getstarted(): 
 | |
|             if x != current: 
 | |
|                 self._trace("joining %s" % x)
 | |
|                 x.join()
 | |
|         self._trace("joining sender/reciver threads finished, current %r" % current) 
 | |
|         if joinexec: 
 | |
|             self._execpool.join()
 | |
|             self._trace("joining execution threads finished, current %r" % current) 
 | |
| 
 | |
| def getid(gw, cache={}):
 | |
|     name = gw.__class__.__name__
 | |
|     try:
 | |
|         return cache.setdefault(name, {})[id(gw)]
 | |
|     except KeyError:
 | |
|         cache[name][id(gw)] = x = "%s:%s.%d" %(os.getpid(), gw.__class__.__name__, len(cache[name]))
 | |
|         return x
 | |
| 
 | |
| registered_cleanup = False
 | |
| _active_sendqueues = weakref.WeakKeyDictionary()
 | |
| def cleanup_atexit():
 | |
|     if debug:
 | |
|         print >>debug, "="*20 + "cleaning up" + "=" * 20
 | |
|         debug.flush()
 | |
|     while True:
 | |
|         try:
 | |
|             queue, ignored = _active_sendqueues.popitem()
 | |
|         except KeyError:
 | |
|             break
 | |
|         queue.put(None)
 |