#!/usr/bin/env python
# -*-mode: python; coding: utf-8 -*-
"""Select-based TCP proxy for RPC record-marked streams.

Each command line argument ``[bind_address:]port:host:hostport`` opens a
listening socket on ``bind_address:port`` and forwards every complete
record received from a client to ``host:hostport`` over one shared server
connection.  The server's reply is sent back to the client that issued the
request.

The framing handled here is a 4-byte big-endian fragment header whose top
bit marks the last fragment of a record and whose low 31 bits give the
fragment payload length (this matches ONC RPC record marking).
"""

# TODO:
# Support for limiting data sizes, max number of connections from the same IP
# etc.

import sys
import time
import socket
import select
import struct

# Connection states, both for client and server connection.
# Client cycle: STATE_READING, WAITING, WRITING
# Server cycle: WAITING, WRITING, STATE_READING
STATE_READING = 0   # Reading record
STATE_WAITING = 2   # Waiting for server response callback, or client mission.
STATE_WRITING = 3   # Writing response to client or request to server
STATE_EOF = 4       # EOF while reading

# Constants
FRAG_HEADER_LEN = 4
FRAG_MAX_SIZE = 2**31 - 1
FRAG_SIZE = FRAG_MAX_SIZE       # Size of newly created fragments
FRAG_LAST_FLAG = 0x80000000     # Header bit: this is the last fragment
FRAG_LEN_MASK = 0x7fffffff      # Header bits: fragment payload length
# Upper bound for a single recv() call.  recv() allocates a buffer of the
# requested size, so a header announcing a huge fragment must not be passed
# through unbounded.
RECV_MAX_SIZE = 65536


class ProxyEngine(object):
    """Registry of the proxies and connections served by the select loop."""

    def __init__(self):
        self.connections = []   # Client or server connections
        self.proxies = []       # Proxy objects

    def add_proxy(self, bind_address, port, host, hostport):
        """Create a Proxy and register it together with its server connection."""
        proxy = Proxy(self, bind_address, port, host, hostport)
        self.proxies.append(proxy)
        self.connections.append(proxy.srv)

    def add_connection(self, conn):
        """Register a new connection with the select loop."""
        self.connections.append(conn)


class Proxy(object):
    """Listening socket that accepts clients for one server endpoint."""

    def __init__(self, pe, bind_address, port, host, hostport):
        self.pe = pe
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((bind_address, port))
        self.sock.listen(1)
        # All clients accepted by this proxy share this server connection.
        self.srv = ServerConnection(host, hostport)

    def fileno(self):
        """Return the listening socket's fileno (used by select)."""
        return self.sock.fileno()

    def handle_read(self):
        """Accept a pending connection and register it as a ClientConnection."""
        sock, addr = self.sock.accept()
        self.pe.add_connection(ClientConnection(sock, addr, self.srv))


class ServerCall(object):
    """A queued request: the payload and the callback that gets the reply."""

    def __init__(self, data, callback):
        self.data = data
        self.callback = callback


class RPCConnection(object):
    """Common record-marked I/O for client and server connections.

    Subclasses are expected to set ``self.state`` to one of the STATE_*
    constants; this base class only reads and updates it.
    """

    def __init__(self):
        self.record = b""   # Current record, as stream with RMs
        self.sndbuf = None
        self.sock = None

    def set_sock(self, sock):
        """Set the socket to use and remember its SO_SNDBUF size."""
        self.sock = sock
        self.sndbuf = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)

    def eof_event(self):
        """Handle a closed connection: mark EOF and close the socket."""
        self.state = STATE_EOF
        self.sock.close()

    def assert_sock(self):
        """Make sure a socket is available.  May be overridden."""
        assert self.sock is not None

    def readable(self):
        """Return true if the connection wants to read."""
        return self.state == STATE_READING

    def writable(self):
        """Return true if the connection wants to write."""
        return self.state == STATE_WRITING

    def eof(self):
        """Return true if EOF has been detected."""
        return self.state == STATE_EOF

    def fileno(self):
        """Return the socket's fileno (used by select)."""
        self.assert_sock()
        return self.sock.fileno()

    def write_record(self):
        """Write as much of the pending record as one send() accepts.

        Returns true when the whole record has been written.
        """
        self.assert_sock()
        # At most SO_SNDBUF bytes are offered per call, as in the original
        # design ("we can write up to SO_SNDBUF without risk blocking").
        wrote = self.sock.send(self.record[:self.sndbuf])
        self.record = self.record[wrote:]
        return len(self.record) == 0

    def frag_length(self, head):
        """Return the length of a fragment, including its header."""
        assert len(head) == FRAG_HEADER_LEN
        marker = struct.unpack('>L', head)[0]
        return int(marker & FRAG_LEN_MASK) + FRAG_HEADER_LEN

    def frag_last(self, head):
        """Return true if the last-fragment flag is set in the header."""
        assert len(head) == FRAG_HEADER_LEN
        marker = struct.unpack('>L', head)[0]
        return (marker & FRAG_LAST_FLAG) != 0

    def rm_stream(self, stream, frag_size=FRAG_SIZE):
        """Record-mark a data stream.

        The stream is split into fragments of at most ``frag_size`` payload
        bytes, each prefixed with its header.  An empty stream yields a
        single empty last fragment.

        Raises ValueError if ``frag_size`` is outside 1..FRAG_MAX_SIZE.
        """
        if not 1 <= frag_size <= FRAG_MAX_SIZE:
            raise ValueError("frag_size must be between 1 and %d"
                             % FRAG_MAX_SIZE)
        total = len(stream)
        pieces = []
        fragpos = 0
        while True:
            frag_data = stream[fragpos:fragpos + frag_size]
            last = fragpos + frag_size >= total
            marker = len(frag_data)
            if last:
                marker |= FRAG_LAST_FLAG
            pieces.append(struct.pack('>L', marker))
            pieces.append(frag_data)
            if last:
                break
            fragpos += len(frag_data)
        return b"".join(pieces)

    def parsed_record(self):
        """Return tuple (data, missing_bytes) for the current record.

        ``data`` is the concatenated payload of all fragments received so
        far (a trailing incomplete fragment contributes the part that has
        arrived).  ``missing_bytes`` is how many bytes must be read next:
        the rest of a header, the rest of a fragment, a full header when
        another fragment is expected, or 0 when the record is complete.
        """
        record = self.record
        total = len(record)
        fragpos = 0
        payload = []
        while True:
            available = total - fragpos
            if available < FRAG_HEADER_LEN:
                # No (complete) header for the next fragment yet.
                return (b"".join(payload), FRAG_HEADER_LEN - available)
            fraghead = record[fragpos:fragpos + FRAG_HEADER_LEN]
            frag_len = self.frag_length(fraghead)
            # Take only this fragment's payload; anything beyond frag_len
            # belongs to the following fragment.
            payload.append(
                record[fragpos + FRAG_HEADER_LEN:fragpos + frag_len])
            if available < frag_len:
                # Incomplete fragment
                return (b"".join(payload), frag_len - available)
            if self.frag_last(fraghead):
                # No need to read anything more
                return (b"".join(payload), 0)
            # Complete, but not the last one: move on to the next fragment.
            fragpos += frag_len

    def read_record(self):
        """Read more of the RPC record.  Returns true if record complete.

        Calls eof_event() and returns 0 if the peer closed the connection.
        """
        self.assert_sock()
        assert self.state == STATE_READING
        bytes_to_read = self.parsed_record()[1]
        if bytes_to_read == 0:
            return 1
        # Never read past what the record needs, and never ask recv() for
        # more than RECV_MAX_SIZE at once.
        data = self.sock.recv(min(bytes_to_read, RECV_MAX_SIZE))
        if not data:
            self.eof_event()
            return 0
        self.record += data
        return self.parsed_record()[1] == 0


class ServerConnection(RPCConnection):
    """Connection to the real server; processes queued calls one at a time."""

    def __init__(self, host, port):
        RPCConnection.__init__(self)
        self.host = host
        self.port = port
        self.calls = []     # A list of ServerCalls
        self.state = STATE_WAITING
        self.current_cb = None

    def eof_event(self):
        """Overridden eof_event, which re-connects."""
        sys.stderr.write("Lost connection to server, trying to reconnect.\n")
        # Discard the current call; its callback gets an empty reply.
        if self.current_cb is not None:
            self.current_cb(b"")
            self.current_cb = None
        if self.calls:
            self.state = STATE_WRITING
        else:
            self.state = STATE_WAITING
        # Re-create socket
        self.sock.close()
        self.sock = None
        self.assert_sock()

    def assert_sock(self):
        """Overridden assert_sock, which connects dynamically.

        Blocks, retrying every 5 seconds, until the server accepts.
        """
        while self.sock is None:
            # A socket is not reused after a failed connect(); create a
            # fresh one for every attempt.
            srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                srv_sock.connect((self.host, self.port))
            except socket.error as e:
                srv_sock.close()
                reason = e.args[-1] if e.args else e
                sys.stderr.write("Connection to %s:%d failed: %s\n"
                                 % (self.host, self.port, reason))
                time.sleep(5)
            else:
                sys.stderr.write("Connected to %s:%d\n"
                                 % (self.host, self.port))
                self.set_sock(srv_sock)

    def call(self, servercall):
        """Put another call on the call queue.

        The call's data is a stream, without RMs.  The callback will be
        called with the result.
        """
        self.calls.append(servercall)
        if self.state == STATE_WAITING:
            self.state = STATE_WRITING

    def handle_read(self):
        """Called when socket is ready for read."""
        if self.read_record():
            # Reply complete: hand it to the caller and pick the next state.
            self.current_cb(self.parsed_record()[0])
            self.current_cb = None
            if self.calls:
                self.state = STATE_WRITING
            else:
                self.state = STATE_WAITING

    def handle_write(self):
        """Called when socket is ready for write."""
        assert self.state == STATE_WRITING
        if self.current_cb is None:
            # Start working on another request
            servercall = self.calls.pop(0)
            self.record = self.rm_stream(servercall.data)
            self.current_cb = servercall.callback
            assert self.current_cb
        if self.write_record():
            self.state = STATE_READING
            self.record = b""


class ClientConnection(RPCConnection):
    """Connection from a client; relays its requests via the server."""

    def __init__(self, sock, addr, srv):
        RPCConnection.__init__(self)
        self.set_sock(sock)
        self.addr = addr
        self.srv = srv
        self.state = STATE_READING

    def handle_read(self):
        """Called when socket is ready for read."""
        if self.read_record():
            # Request complete: queue it and wait for got_response().
            self.state = STATE_WAITING
            self.srv.call(ServerCall(self.parsed_record()[0],
                                     self.got_response))

    def handle_write(self):
        """Called when socket is ready for write."""
        assert self.state == STATE_WRITING
        if self.write_record():
            self.state = STATE_READING
            self.record = b""

    def got_response(self, data):
        """Callback: We got a response from the server."""
        # send to client
        self.state = STATE_WRITING
        self.record = self.rm_stream(data)


def usage():
    """Exit with a usage message."""
    sys.exit("Usage: %s [bind_address:]port:host:hostport ..." % sys.argv[0])


def parse_arg(arg):
    """Parse a command argument, specifying hosts and ports.

    Returns tuple (bind_address, port, host, hostport).  The bind address
    defaults to 127.0.0.1 when only three fields are given; any other field
    count exits via usage().
    """
    fields = arg.split(":")
    if len(fields) == 3:
        fields.insert(0, "127.0.0.1")
    if len(fields) != 4:
        usage()
    bind_address, port, host, hostport = fields
    port = int(port)
    hostport = int(hostport)
    return bind_address, port, host, hostport


def main():
    """Set up one proxy per command line argument and run the select loop."""
    if len(sys.argv) < 2:
        usage()

    pe = ProxyEngine()

    #
    # Determine hosts and ports
    #
    for arg in sys.argv[1:]:
        pe.add_proxy(*parse_arg(arg))

    #
    # Select loop
    #
    while True:
        # Set up sets
        read_set = []
        read_set.extend(pe.proxies)
        write_set = []
        for conn in pe.connections:
            if conn.readable():
                read_set.append(conn)
            if conn.writable():
                write_set.append(conn)

        rlist, wlist, _ = select.select(read_set, write_set, [])

        for obj in rlist:
            obj.handle_read()

        for obj in wlist:
            obj.handle_write()

        # Drop closed connections.  The list is rebuilt in place rather
        # than mutated while being iterated, which would skip entries.
        pe.connections[:] = [conn for conn in pe.connections
                             if not conn.eof()]


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(0)