Sophie

Sophie

distrib > Mageia > 5 > i586 > media > core-release > by-pkgid > f03dcf88901fda23f88935a60a2fd973 > files > 20

unfs3-0.9.22-7.mga5.i586.rpm

#!/usr/bin/env python
# -*-mode: python; coding: utf-8 -*-

# TODO:
# Support for limiting data sizes, max number of connections from the same IP etc

import errno
import select
import socket
import struct
import sys
import time

# Connection states, shared by client and server connections.
# A client connection cycles READING -> WAITING -> WRITING (-> READING).
# A server connection cycles WAITING -> WRITING -> READING (-> WAITING).
STATE_READING = 0  # Reading a record from the peer
STATE_WAITING = 2  # Client: waiting for the server's reply callback.
                   # Server: idle until a client call is queued.
STATE_WRITING = 3  # Writing a reply to the client or a request to the server
STATE_EOF = 4      # The peer closed the connection

# Record-marking constants
FRAG_HEADER_LEN = 4          # Bytes in a fragment header
FRAG_MAX_SIZE = 0x7FFFFFFF   # Largest payload length a header can encode (31 bits)
FRAG_SIZE = FRAG_MAX_SIZE    # Payload size used when building outgoing fragments


class ProxyEngine:
    """Registry of listening proxies and of the connections polled by the
    select loop."""

    def __init__(self):
        self.proxies = []      # Listening Proxy objects
        self.connections = []  # Client and server RPC connections


    def add_proxy(self, bind_address, port, host, hostport):
        """Create a Proxy listening on bind_address:port that forwards to
        host:hostport, and register it along with its server connection."""
        new_proxy = Proxy(self, bind_address, port, host, hostport)
        self.proxies.append(new_proxy)
        self.connections.append(new_proxy.srv)

    def add_connection(self, conn):
        """Register one more connection with the select loop."""
        self.connections += [conn]


class Proxy:
    """A listening socket whose accepted clients are forwarded to one server.

    Every client accepted by this proxy shares the single ServerConnection
    to host:hostport created here.
    """

    def __init__(self, pe, bind_address, port, host, hostport):
        self.pe = pe
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind((bind_address, port))
            self.sock.listen(1)
        except Exception:
            # Don't leak the descriptor when the address cannot be used.
            self.sock.close()
            raise
        self.srv = ServerConnection(host, hostport)


    def fileno(self):
        """Return the listening socket's fileno (used by select())."""
        return self.sock.fileno()


    def handle_read(self):
        """Accept a pending client and register a ClientConnection for it.

        Returns None; the new connection is added to the ProxyEngine.
        """
        sock, addr = self.sock.accept()
        self.pe.add_connection(ClientConnection(sock, addr, self.srv))


class ServerCall:
    """A queued request for the server.

    data is the request payload without record marks; callback is invoked
    with the reply payload once the server has answered.
    """

    def __init__(self, data, callback):
        self.data, self.callback = data, callback


class RPCConnection:
    """Base class for a socket carrying record-marked RPC messages.

    A record is sent as one or more fragments. Each fragment is preceded by a
    4-byte big-endian header: the high bit marks the last fragment of the
    record, and the low 31 bits give the fragment's payload length.
    ``self.record`` holds the record currently being read or written, headers
    included. Subclasses set ``self.state`` to one of the STATE_* constants
    and implement handle_read()/handle_write() for the select loop.
    """

    # Upper bound on a single recv(). The number of missing bytes comes from
    # a peer-supplied fragment length (up to 2**31 - 1). CPython's recv()
    # allocates a buffer of the requested size up front, so an uncapped read
    # lets one bogus header trigger a ~2 GiB allocation. Larger fragments are
    # simply read over several select() rounds.
    RECV_CHUNK_SIZE = 65536

    def __init__(self):
        self.record = b""   # Current record, as stream with record marks
        self.sndbuf = None  # SO_SNDBUF of self.sock; set by set_sock()
        self.sock = None


    def set_sock(self, sock):
        """Set socket to use and remember its send buffer size."""
        self.sock = sock
        self.sndbuf = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)


    def eof_event(self):
        """The peer closed (or broke) the connection.

        Called when recv() returns no data, or when recv()/send() fails with
        a socket error other than EINTR. Marks the connection as EOF and
        closes the socket. May be overridden.
        """
        self.state = STATE_EOF
        self.sock.close()


    def _handle_socket_error(self, err):
        """React to a socket.error raised by recv() or send().

        EINTR (interrupted system call) is harmless: the state is left
        unchanged, so the select loop retries on its next round. Any other
        error means the connection is unusable and is handled as EOF.
        """
        if err.args and err.args[0] == errno.EINTR:
            return
        self.eof_event()


    def assert_sock(self):
        """Make sure a socket is available. May be overridden."""
        assert self.sock is not None


    def readable(self):
        """Returns true if connection wants to read"""
        return self.state == STATE_READING


    def writable(self):
        """Returns true if connection wants to write"""
        return self.state == STATE_WRITING


    def eof(self):
        """Returns true if EOF has been detected"""
        return self.state == STATE_EOF


    def fileno(self):
        """Return the socket's fileno (used by select())."""
        self.assert_sock()
        return self.sock.fileno()


    def write_record(self):
        """Send the next chunk of self.record.

        Returns true once the whole record has been written. On a socket
        error the error is handled via _handle_socket_error() and False is
        returned, so callers do not advance their state machine.
        """
        self.assert_sock()
        try:
            # Write at most SO_SNDBUF bytes per call so that, once select()
            # has reported the socket writable, send() is unlikely to block.
            wrote = self.sock.send(self.record[:self.sndbuf])
        except socket.error as err:
            self._handle_socket_error(err)
            return False
        self.record = self.record[wrote:]
        return len(self.record) == 0


    def frag_length(self, head):
        """Return the length of a fragment, including header"""
        assert len(head) == FRAG_HEADER_LEN
        x = struct.unpack('>L', head)[0]
        return int(x & 0x7fffffff) + FRAG_HEADER_LEN


    def frag_last(self, head):
        """Return true if last flag is set"""
        assert len(head) == FRAG_HEADER_LEN
        x = struct.unpack('>L', head)[0]
        return (x & 0x80000000) != 0


    def rm_stream(self, stream):
        """Record-mark a data stream.

        Splits stream into fragments of at most FRAG_SIZE bytes, each preceded
        by its header; the last one has the high bit set. An empty stream
        yields a single empty last fragment.
        """
        fragpos = 0
        pieces = []

        while True:
            last = (fragpos + FRAG_SIZE >= len(stream))
            frag_data = stream[fragpos:fragpos + FRAG_SIZE]
            header = len(frag_data)
            if last:
                header |= 0x80000000
            pieces.append(struct.pack('>L', header))
            pieces.append(frag_data)
            if last:
                break
            fragpos += len(frag_data)

        return b"".join(pieces)


    def parsed_record(self):
        """Parse self.record into (data, missing_bytes).

        data is the concatenation of the fragment payloads seen so far,
        without headers. missing_bytes is how many more bytes must be read
        to finish the current header or fragment; it is 0 once the fragment
        flagged as last is complete.
        """
        record = self.record
        fragpos = 0
        data = []
        while True:
            avail = len(record) - fragpos
            if avail < FRAG_HEADER_LEN:
                # Need (the rest of) the next fragment header.
                return (b"".join(data), FRAG_HEADER_LEN - avail)

            fraghead = record[fragpos:fragpos + FRAG_HEADER_LEN]
            len_from_head = self.frag_length(fraghead)
            if avail < len_from_head:
                # Incomplete fragment
                data.append(record[fragpos + FRAG_HEADER_LEN:])
                return (b"".join(data), len_from_head - avail)

            # Complete fragment: take only its own payload.
            data.append(record[fragpos + FRAG_HEADER_LEN:fragpos + len_from_head])
            if self.frag_last(fraghead):
                # No need to read anything more
                return (b"".join(data), 0)
            # Continue with the fragment that follows this one.
            fragpos += len_from_head


    def read_record(self):
        """Read RPC record. Returns true if record complete.

        Reads at most RECV_CHUNK_SIZE bytes per call. EOF or a socket error
        is handled via eof_event()/_handle_socket_error() and reported as
        "not complete".
        """
        self.assert_sock()
        assert self.state == STATE_READING
        bytes_to_read = self.parsed_record()[1]
        if bytes_to_read == 0:
            return 1

        try:
            data = self.sock.recv(min(bytes_to_read, self.RECV_CHUNK_SIZE))
        except socket.error as err:
            self._handle_socket_error(err)
            return 0

        if not data:
            self.eof_event()
            return 0

        self.record += data
        return self.parsed_record()[1] == 0


class ServerConnection(RPCConnection):
    """Connection from the proxy to the real RPC server.

    Client requests are queued as ServerCalls and sent one at a time; the
    reply payload is passed to the call's callback before the next request
    is sent. The socket is connected lazily by assert_sock(), and
    re-connected by eof_event() after the connection is lost.
    """

    # Seconds to sleep between failed connection attempts.
    RECONNECT_DELAY = 5

    def __init__(self, host, port):
        RPCConnection.__init__(self)
        self.host = host
        self.port = port
        self.calls = []  # Pending ServerCalls, oldest first
        self.state = STATE_WAITING
        self.current_cb = None  # Callback of the request in flight, if any


    def eof_event(self):
        """Overridden eof_event: discard the in-flight call and reconnect.

        The in-flight callback (if any) is invoked with an empty payload.
        Reconnecting blocks until the server accepts a new connection.
        """
        sys.stderr.write("Lost connection to server, trying to reconnect.\n")
        # Discard the current call
        if self.current_cb is not None:
            self.current_cb("")
            self.current_cb = None
        if self.calls:
            self.state = STATE_WRITING
        else:
            self.state = STATE_WAITING

        # Re-create socket
        self.sock.close()
        self.sock = None
        self.assert_sock()


    def assert_sock(self):
        """Overridden assert_sock, which connects dynamically.

        Retries every RECONNECT_DELAY seconds until connect() succeeds.
        """
        if self.sock is not None:
            return
        while True:
            # A socket whose connect() failed is in an unspecified state
            # (POSIX), so every attempt uses a fresh socket.
            srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                srv_sock.connect((self.host, self.port))
            except socket.error as e:
                srv_sock.close()
                # Most socket errors carry (errno, message); fall back to the
                # exception itself rather than raising IndexError.
                reason = e.args[1] if len(e.args) > 1 else e
                sys.stderr.write("Connection to %s:%d failed: %s\n"
                                 % (self.host, self.port, reason))
                time.sleep(self.RECONNECT_DELAY)
                continue
            sys.stderr.write("Connected to %s:%d\n" % (self.host, self.port))
            break
        self.set_sock(srv_sock)


    def call(self, servercall):
        """Put another call on the call queue. The call argument is a
        stream, without RMs. The callback will be called with result"""
        self.calls.append(servercall)
        if self.state == STATE_WAITING:
            self.state = STATE_WRITING


    def handle_read(self):
        """Called when socket is ready for read: deliver a complete reply."""
        if self.read_record():
            self.current_cb(self.parsed_record()[0])
            self.current_cb = None
            if self.calls:
                self.state = STATE_WRITING
            else:
                self.state = STATE_WAITING


    def handle_write(self):
        """Called when socket is ready for write: send the next request."""
        assert self.state == STATE_WRITING
        if self.current_cb is None:
            # Start working on another request
            servercall = self.calls.pop(0)
            self.record = self.rm_stream(servercall.data)
            self.current_cb = servercall.callback

        assert self.current_cb
        if self.write_record():
            self.state = STATE_READING
            self.record = ""


class ClientConnection(RPCConnection):
    """Connection accepted from an RPC client.

    Requests are handled strictly one at a time. A full record is read, its
    payload is handed to the shared ServerConnection, the reply arrives via
    got_response(), and the reply is written back before the next request
    is read.
    """

    def __init__(self, sock, addr, srv):
        RPCConnection.__init__(self)
        self.set_sock(sock)
        self.addr = addr  # Peer address as returned by accept()
        self.srv = srv    # ServerConnection that forwards our requests
        self.state = STATE_READING


    def handle_read(self):
        """Socket is readable: once a full request is in, forward it."""
        if not self.read_record():
            return
        self.state = STATE_WAITING
        payload = self.parsed_record()[0]
        self.srv.call(ServerCall(payload, self.got_response))


    def handle_write(self):
        """Socket is writable: send more of the reply, and go back to
        reading once all of it has been written."""
        assert self.state == STATE_WRITING
        finished = self.write_record()
        if finished:
            self.state = STATE_READING
            self.record = ""


    def got_response(self, data):
        """Callback with the server's reply payload *data*: record-mark it
        and schedule it for writing to the client."""
        self.state = STATE_WRITING
        self.record = self.rm_stream(data)


def usage():
    """Exit the program, printing the command-line synopsis to stderr."""
    program = sys.argv[0]
    sys.exit("Usage: %s [bind_address:]port:host:hostport ..." % program)


def parse_arg(arg):
    """Parse a command argument, specifying hosts and ports.

    The argument has the form ``[bind_address:]port:host:hostport``. When
    bind_address is omitted it defaults to 127.0.0.1. A malformed argument
    (wrong number of fields, or a non-numeric or out-of-range port) makes
    the program exit via usage() instead of failing later with a traceback.

    Returns tuple (bind_address, port, host, hostport) with integer ports.
    """
    fields = arg.split(":")
    if len(fields) == 3:
        fields.insert(0, "127.0.0.1")

    if len(fields) != 4:
        usage()

    bind_address, port, host, hostport = fields
    try:
        port = int(port)
        hostport = int(hostport)
    except ValueError:
        usage()
    if not (0 <= port <= 65535 and 0 <= hostport <= 65535):
        usage()
    return bind_address, port, host, hostport


def main():
    """Parse the command line, then run the select()-based proxy loop forever."""
    if len(sys.argv) < 2:
        usage()

    pe = ProxyEngine()

    #
    # Determine hosts and ports
    #
    for arg in sys.argv[1:]:
        pe.add_proxy(*parse_arg(arg))

    #
    # Select loop
    #
    while 1:
        # Listening sockets are always polled for reading; connections are
        # polled according to their current state.
        read_set = list(pe.proxies)
        write_set = []
        for conn in pe.connections:
            if conn.readable():
                read_set.append(conn)
            if conn.writable():
                write_set.append(conn)

        rlist, wlist, _ = select.select(read_set, write_set, [])

        for obj in rlist:
            obj.handle_read()

        for obj in wlist:
            obj.handle_write()

        # Drop connections that reached EOF. Rebuild the list in place rather
        # than calling remove() while iterating it, which skips the element
        # following each removed one.
        pe.connections[:] = [conn for conn in pe.connections if not conn.eof()]


if __name__ == "__main__":
    # Ctrl-C is the normal way to stop the proxy: exit with status 0
    # instead of printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(0)