307 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			307 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
| #import os
 | |
| import struct
 | |
| from collections import deque
 | |
| 
 | |
| 
 | |
class InvalidPacket(Exception):
    """Raised by PipeLayer.decode() when the received data is too short to
    hold a packet header or carries a flag byte outside the known range."""
 | |
| 
 | |
| 
 | |
# Values of the flag byte that starts every packet header.
FLAG_NAK1 = 0xE0   # NAK; the sender of it already holds the packet after the missing one
FLAG_NAK  = 0xE1   # NAK: at least one packet is missing
FLAG_REG  = 0xE2   # regular packet
FLAG_CFRM = 0xE3   # pure ACK sent while own data is unacknowledged: please confirm

# PipeLayer.decode() rejects any flag byte outside [START, STOP).
FLAG_RANGE_START  = 0xE0
FLAG_RANGE_STOP   = 0xE4

# Maximum number of sent-but-unacknowledged packets (the send window).
# ACKs carry only the low 8 bits of the sequence number, so this must be
# < 256: with 256 packets in flight, an ACK covering none of them would be
# indistinguishable from one covering all of them, and every outstanding
# packet would be discarded as acknowledged.
max_old_packets = 255
 | |
| 
 | |
| 
 | |
class PipeLayer(object):
    """Reliable, in-order delivery of a byte stream over an unreliable
    packet channel.

    The object performs no I/O itself.  The caller queues outgoing data
    with queue(), reports the current time with settime(), asks encode()
    for the next packet to transmit, and feeds every received packet to
    decode(), which returns the payload that has become deliverable in
    order.

    Every packet starts with a header packed as "!BBH": a flag byte
    (FLAG_*), the low 8 bits of the next sequence number this side expects
    to receive (the acknowledgement), and the 16-bit sequence number of the
    packet itself.  Packets with an empty payload are pure acknowledgements
    and do not consume a sequence number.
    """
    timeout = 1       # re-send delay, in the unit of the times passed to settime()
    headersize = 4    # bytes taken by the "!BBH" header

    def __init__(self):
        self.cur_time = 0
        # Outgoing chunks: queue() adds at the left and encode() takes from
        # the right, so the next chunk to send is always out_queue[-1].
        self.out_queue = deque()
        self.out_nextseqid = 0          # seqid of the next packet encode() sends
        self.out_nextrepeattime = None  # when to transmit again if nothing arrives
        self.in_nextseqid = 0           # next seqid expected from the peer
        self.in_outoforder = {}         # seqid -> payload received ahead of in_nextseqid
        # Sent payloads awaiting acknowledgement, newest at the left.
        self.out_oldpackets = deque()
        self.out_flags = FLAG_REG       # flag byte for the next packet sent
        self.out_resend = 0             # how many chunks at the right of out_queue are re-sends
        self.out_resend_skip = False    # after a NAK1: skip the second re-sent packet

    def queue(self, data):
        """Append *data* to the outgoing stream; empty data is ignored."""
        if data:
            self.out_queue.appendleft(data)

    def queue_size(self):
        """Return the number of bytes currently waiting in the outgoing queue
        (including chunks queued for re-sending)."""
        return sum(len(data) for data in self.out_queue)

    def in_sync(self):
        """Return True when nothing is queued for sending and no
        (re)transmission is scheduled."""
        return not self.out_queue and self.out_nextrepeattime is None

    def _congested(self):
        """Return True if encode() must not send a packet with payload: the
        window of unacknowledged packets is full and there is nothing to
        re-send (re-sent packets are already in the window, so they never
        grow it and must not be blocked by it)."""
        return (self.out_resend == 0 and
                len(self.out_oldpackets) >= max_old_packets)

    def settime(self, curtime):
        """Record the current time and return how long the caller may wait
        before calling encode(): 0 if a data packet can be sent right away,
        the time left until the next scheduled transmission, or None if
        nothing is scheduled."""
        self.cur_time = curtime
        if self.out_queue and not self._congested():
            return 0   # more data to send now
        if self.out_nextrepeattime is not None:
            return max(0, self.out_nextrepeattime - curtime)
        return None

    def encode(self, maxlength):
        """Return the next packet to transmit, at most *maxlength* bytes, or
        None if nothing needs to be sent now.

        A packet without payload (a pure acknowledgement) is produced when a
        scheduled transmission is due but no data can be sent.

        Raises ValueError if *maxlength* leaves no room for any payload, or
        if a packet being re-sent no longer fits into *maxlength*.
        """
        if self._congested():
            payload = 0   # window full and nothing to re-send: stall new data
        else:
            payload = maxlength - self.headersize
            if payload <= 0:
                raise ValueError("encode(): buffer too small")
        if (self.out_nextrepeattime is not None and
                self.out_nextrepeattime <= self.cur_time):
            # no ACK received so far, send a packet (possibly empty)
            if not self.out_queue:
                payload = 0
        else:
            if not self.out_queue:   # no more data to send
                return None
            if payload == 0:         # congestion
                return None
        # prepare a packet
        seqid = self.out_nextseqid
        flags = self.out_flags
        self.out_flags = FLAG_REG     # clear out the flags for the next time
        if payload > 0:
            self.out_nextseqid = (seqid + 1) & 0xFFFF
            data = self.out_queue.pop()
            packetlength = len(data)
            if self.out_resend > 0:
                # Re-send a packet already recorded in out_oldpackets; it must
                # go out unchanged so that its sequence number stays valid.
                if packetlength > payload:
                    raise ValueError("XXX need constant buffer size for now")
                self.out_resend -= 1
                if self.out_resend_skip:
                    # NAK1: the peer already holds the packet after this one.
                    if self.out_resend > 0:
                        self.out_queue.pop()
                        self.out_resend -= 1
                        self.out_nextseqid = (seqid + 2) & 0xFFFF
                    self.out_resend_skip = False
                packetpayload = data
            else:
                # New data: pack as many queued chunks as fit, splitting the
                # last one and pushing its remainder back for the next packet.
                packet = []
                while packetlength <= payload:
                    packet.append(data)
                    if not self.out_queue:
                        break
                    data = self.out_queue.pop()
                    packetlength += len(data)
                else:
                    rest = len(data) + payload - packetlength
                    packet.append(data[:rest])
                    self.out_queue.append(data[rest:])
                packetpayload = b''.join(packet)
                self.out_oldpackets.appendleft(packetpayload)
        else:
            # a pure ACK packet, no payload
            if self.out_oldpackets and flags == FLAG_REG:
                flags = FLAG_CFRM   # also ask the peer to acknowledge our data
            packetpayload = b''
        packet = struct.pack("!BBH", flags,
                             self.in_nextseqid & 0xFF,
                             seqid) + packetpayload
        if self.out_oldpackets:
            self.out_nextrepeattime = self.cur_time + self.timeout
        else:
            self.out_nextrepeattime = None
        return packet

    def decode(self, rawdata):
        """Process one received packet and return the payload bytes that are
        now deliverable in order (possibly empty).

        Raises InvalidPacket if *rawdata* is shorter than the header or has
        a flag byte outside [FLAG_RANGE_START, FLAG_RANGE_STOP).
        """
        if len(rawdata) < self.headersize:
            raise InvalidPacket
        in_flags, ack_seqid, in_seqid = struct.unpack(
            "!BBH", rawdata[:self.headersize])
        if not (FLAG_RANGE_START <= in_flags < FLAG_RANGE_STOP):
            raise InvalidPacket
        payload = rawdata[self.headersize:]
        # How far this packet is ahead of the one we expect (16-bit wrap).
        in_diff = (in_seqid - self.in_nextseqid) & 0xFFFF
        # How many of our packets the peer has not acknowledged (8-bit wrap);
        # out_nextseqid + out_resend is the seqid of the next *new* packet.
        ack_diff = (self.out_nextseqid + self.out_resend - ack_seqid) & 0xFF
        if in_diff >= max_old_packets:
            return b''    # invalid, but can occur as a late repetition
        if ack_diff != len(self.out_oldpackets):
            if ack_diff > len(self.out_oldpackets):
                return b''   # invalid, but can occur with packet reordering
            # forget all acknowledged packets (the oldest are at the right)
            while len(self.out_oldpackets) > ack_diff:
                self.out_oldpackets.pop()
            if self.out_oldpackets:
                self.out_nextrepeattime = self.cur_time + self.timeout
            else:
                self.out_nextrepeattime = None   # all packets ACKed
        if in_flags in (FLAG_NAK, FLAG_NAK1):
            # This is a NAK: resend the old packets as far as they've not
            # also been ACK'ed in the meantime (can occur with reordering).
            # Entries are appended newest-first, so the oldest packet ends up
            # at the right of out_queue and is sent first; each one rewinds
            # out_nextseqid so it goes out with its original sequence number.
            while self.out_resend < len(self.out_oldpackets):
                self.out_queue.append(self.out_oldpackets[self.out_resend])
                self.out_resend += 1
                self.out_nextseqid = (self.out_nextseqid - 1) & 0xFFFF
            self.out_resend_skip = in_flags == FLAG_NAK1
        elif in_flags == FLAG_CFRM:
            # this is a CFRM: request for confirmation, answer right away
            self.out_nextrepeattime = self.cur_time
        # receive this packet's payload if it is the next in the sequence
        if in_diff == 0:
            if payload:
                self.in_nextseqid = (self.in_nextseqid + 1) & 0xFFFF
                result = [payload]
                # deliver any directly following packets received earlier
                while self.in_nextseqid in self.in_outoforder:
                    result.append(self.in_outoforder.pop(self.in_nextseqid))
                    self.in_nextseqid = (self.in_nextseqid + 1) & 0xFFFF
                return b''.join(result)
        else:
            # we missed at least one intermediate packet: send a NAK
            if payload:
                self.in_outoforder[in_seqid] = payload
            if ((self.in_nextseqid + 1) & 0xFFFF) in self.in_outoforder:
                self.out_flags = FLAG_NAK1
            else:
                self.out_flags = FLAG_NAK
            self.out_nextrepeattime = self.cur_time
        return b''

    _dump_indent = 0   # number of spaces dump() prints before each line

    def dump(self, dir, rawdata):
        """Debugging aid: write a one-line description of the packet
        *rawdata* to stdout, labelled with *dir*."""
        import sys
        in_flags, ack_seqid, in_seqid = struct.unpack("!BBH", rawdata[:4])
        names = {FLAG_NAK: 'NAK', FLAG_NAK1: 'NAK1', FLAG_CFRM: 'CFRM'}
        words = [' ' * self._dump_indent, dir]
        if in_flags in names:
            words.append(names[in_flags])
        words += [str(ack_seqid), str(in_seqid), repr(rawdata[4:])]
        sys.stdout.write(' '.join(words) + '\n')
 | |
| 
 | |
| 
 | |
def pipe_over_udp(udpsock, send_fd=-1, recv_fd=-1,
                  timeout=1.0, inactivity_timeout=None):
    """Example: send all data showing up in send_fd over the given UDP
    socket, and write incoming data into recv_fd.

    The send_fd and recv_fd are plain file descriptors; pass a negative
    value to disable a direction (data received while recv_fd < 0 is
    discarded).  *timeout* is the re-send delay of the underlying
    PipeLayer, in seconds.  When an EOF is read from send_fd, this function
    returns after making sure that all data was received by the remote
    side.  If *inactivity_timeout* is not None, it also returns once that
    many seconds pass with nothing to send and nothing received.

    *udpsock* is used with send()/recv() without addresses, so it is
    presumably already connected to the peer.
    """
    import os
    from select import select
    from time import time
    mtu = 1500   # maximum size of a UDP packet produced here, header included
    p = PipeLayer()
    p.timeout = timeout
    iwtdlist = [udpsock]
    if send_fd >= 0:
        iwtdlist.append(send_fd)
    running = True
    while running or not p.in_sync():
        delay = delay1 = p.settime(time())
        if delay is None:
            delay = inactivity_timeout
        iwtd, owtd, ewtd = select(iwtdlist, [], [], delay)
        if iwtd:
            if send_fd in iwtd:
                data = os.read(send_fd, mtu - p.headersize)
                if not data:
                    # EOF: stop reading, but keep going until in sync
                    iwtdlist.remove(send_fd)
                    running = False
                else:
                    p.queue(data)
            if udpsock in iwtd:
                packet = udpsock.recv(65535)
                p.settime(time())
                data = p.decode(packet)
                if recv_fd >= 0:
                    # os.write() may write less than asked: loop until done
                    i = 0
                    while i < len(data):
                        i += os.write(recv_fd, data[i:])
        elif delay1 is None:
            break    # long inactivity
        p.settime(time())
        packet = p.encode(mtu)
        if packet:
            udpsock.send(packet)
 | |
| 
 | |
| 
 | |
class PipeOverUdp(object):
    """Socket-like front end for pipe_over_udp() running in a background
    thread.

    Data written with send()/sendall() goes into a pipe whose read end is
    handed to the worker thread as its send_fd; the worker writes received
    data into a second pipe, which recv()/recvall() read from.  fileno()
    and ofileno() expose the pipe ends used by this object.
    """

    def __init__(self, udpsock, timeout=1.0):
        import os
        try:
            import thread
        except ImportError:   # Python 3 renamed the module
            import _thread as thread
        self.os = os   # kept so that __del__ still works at interpreter shutdown
        # Pre-set so that __del__ can clean up if a later os.pipe() fails.
        self.sendpipe = self.recvpipe = None
        self.sendpipe = os.pipe()
        self.recvpipe = os.pipe()
        thread.start_new_thread(pipe_over_udp, (udpsock,
                                                self.sendpipe[0],
                                                self.recvpipe[1],
                                                timeout))

    def __del__(self):
        """Close both pipes; calling it again is harmless.

        NOTE(review): this also closes sendpipe[0] and recvpipe[1], which the
        worker thread may still be using, and it does not wait for the
        thread to flush pending data.
        """
        os = getattr(self, 'os', None)
        if os is None:
            return   # __init__ never got far enough to open anything
        sendpipe = getattr(self, 'sendpipe', None)
        if sendpipe:
            os.close(sendpipe[0])
            os.close(sendpipe[1])
            self.sendpipe = None
        recvpipe = getattr(self, 'recvpipe', None)
        if recvpipe:
            os.close(recvpipe[0])
            os.close(recvpipe[1])
            self.recvpipe = None

    close = __del__

    def send(self, data):
        """Write some of *data* to the outgoing pipe; return the byte count."""
        if not self.sendpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.os.write(self.sendpipe[1], data)

    def sendall(self, data):
        """Write all of *data* to the outgoing pipe."""
        i = 0
        while i < len(data):
            i += self.send(data[i:])

    def recv(self, bufsize):
        """Read up to *bufsize* bytes of received data (blocks if none)."""
        if not self.recvpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.os.read(self.recvpipe[0], bufsize)

    def recvall(self, bufsize):
        """Read exactly *bufsize* bytes of received data, blocking as needed.

        Fewer bytes are returned only if the pipe reaches end-of-file first.
        """
        buf = []
        while bufsize > 0:
            data = self.recv(bufsize)
            if not data:   # EOF: nothing more will ever arrive
                break
            buf.append(data)
            bufsize -= len(data)
        return b''.join(buf)

    def fileno(self):
        """Return the descriptor to read received data from (for select())."""
        if not self.recvpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.recvpipe[0]

    def ofileno(self):
        """Return the descriptor that outgoing data is written to."""
        if not self.sendpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.sendpipe[1]
 |