#!/usr/bin/python
#
# Copyright 2001 by Object Craft P/L, Melbourne, Australia.
#
# LICENCE - see LICENCE file distributed with this software for details.
#
import socket
import select
import random
import time
import binascii
import struct
import bisect
import sys
import getopt
import errno
import traceback

# Limit on server-side exceptions a single client may trigger before we
# close our connection to them. Each failure adds one to the client's error
# count and each successfully handled command removes one, so this bounds net
# (not strictly consecutive) failures; the connection is dropped once the
# count exceeds this value.
MAX_CLIENT_EXCEPTIONS = 3


class Session:
    """One stored session: id, lifetime in seconds, payload text and the
    time it was last touched."""

    def __init__(self, sesid, life):
        # Independent assignments; order is irrelevant to the result.
        self.last_access = time.time()
        self.text = ''
        self.life = life
        self.sesid = sesid

    def die_time(self):
        """Return the absolute time at which this session becomes stale."""
        expiry = self.last_access + self.life
        return expiry


class SessionDict:
    """In-memory store of Session objects keyed by session id.

    ``usage`` is the running total of the lengths of all stored session
    texts, kept up to date by set_session/del_session.
    """

    def __init__(self):
        self.dict = {}
        self.usage = 0

    def new_session(self, sesid, life):
        """Create, store and return a fresh empty Session for sesid.

        Any session already stored under sesid is removed first so that its
        text stops being counted in ``usage``.
        """
        self.del_session(sesid)
        ses = Session(sesid, life)
        self.dict[sesid] = ses
        return ses

    def set_session(self, sesid, text):
        """Replace the text of sesid and mark it accessed now.

        Unknown session ids are silently ignored.
        """
        ses = self.dict.get(sesid)
        if ses is None:
            return
        self.usage = self.usage - len(ses.text) + len(text)
        ses.text = text
        ses.last_access = time.time()

    def get_session(self, sesid, reset_access = 1):
        """Return the Session stored under sesid, or None.

        Unless reset_access is false, the session's last_access is set to
        now, which pushes back its die_time().
        """
        ses = self.dict.get(sesid)
        if ses and reset_access:
            ses.last_access = time.time()
        return ses

    def del_session(self, sesid):
        """Remove sesid if present and subtract its text from ``usage``."""
        if sesid in self.dict:
            self.usage = self.usage - len(self.dict[sesid].text)
            del self.dict[sesid]

    def has_session(self, sesid):
        """Return true if a session is stored under sesid."""
        return sesid in self.dict


class SesLife:
    """A scheduled expiry check: a session id and the time it falls due.

    Instances order by ``die_time`` only, so a list of them can be kept
    sorted with the bisect module.
    """

    def __init__(self, sesid, die_time):
        self.sesid = sesid
        self.die_time = die_time

    def __cmp__(self, other):
        # Python 2 three-way comparison (ignored by Python 3).
        return cmp(self.die_time, other.die_time)

    def __lt__(self, other):
        # bisect only uses "<"; defining it directly keeps the ordering
        # working where __cmp__/cmp() do not exist (Python 3), and matches
        # __cmp__ exactly on Python 2.
        return self.die_time < other.die_time


class Reaper:
    """Expire sessions whose lifetime has elapsed since their last access.

    ``self.list`` holds SesLife records kept sorted by ``die_time``;
    ``sessions`` is the SessionDict they refer to.
    """

    def __init__(self, sessions):
        self.sessions = sessions
        self.list = []

    def add(self, sesid, life):
        """Schedule an expiry check for sesid, life seconds from now."""
        deadline = time.time() + life
        bisect.insort(self.list, SesLife(sesid, deadline))

    def process(self):
        """Reap overdue sessions and reschedule those touched since.

        Returns the seconds until the next scheduled check, or None when
        nothing is scheduled.
        """
        now = time.time()
        pending = self.list
        while pending and pending[0].die_time <= now:
            due = pending.pop(0)
            ses = self.sessions.get_session(due.sesid, 0)
            if not ses:
                # No longer stored; nothing to expire.
                continue
            expiry = ses.die_time()
            if expiry > now:
                # Accessed since this check was scheduled: push it back.
                due.die_time = expiry
                bisect.insort(pending, due)
            else:
                self.sessions.del_session(due.sesid)
        if not pending:
            return None
        return pending[0].die_time - now


class Context:
    """State shared by the server loop and its clients.

    Holds the session store, the expiry reaper, the lists of objects that
    select() watches, and an optional log file.
    """

    def __init__(self, log_file):
        self.log_file = log_file
        self.sessions = SessionDict()
        self.reaper = Reaper(self.sessions)

        # Objects with a fileno() (the listening socket and Client
        # instances) watched by select() for readability / writability.
        self.read_files = []
        self.write_files = []

    def log(self, msg):
        """Append msg to the log file, if one is set, and flush it."""
        if self.log_file:
            self.log_file.write(msg)
            self.log_file.flush()

    def add_write_file(self, file):
        """Start watching file for writability."""
        self.write_files.append(file)

    def del_write_file(self, file):
        """Stop watching file for writability; no-op if not watched."""
        try:
            self.write_files.remove(file)
        except ValueError:
            pass

    def add_read_file(self, file):
        """Start watching file for readability."""
        self.read_files.append(file)

    def del_read_file(self, file):
        """Stop watching file for readability; no-op if not watched."""
        try:
            self.read_files.remove(file)
        except ValueError:
            pass

    def del_file(self, file):
        """Stop watching file entirely."""
        self.del_read_file(file)
        self.del_write_file(file)

    def select(self):
        """Run session expiry, then wait for I/O readiness.

        The reaper's next deadline is the select() timeout (None blocks
        indefinitely), so sessions expire even when no I/O arrives.
        Returns (readable, writable); both are empty if select() failed
        with EAGAIN or EINTR. Other select errors are re-raised.
        """
        timeout = self.reaper.process()
        try:
            read_out, write_out, _oob = \
                select.select(self.read_files, self.write_files, [], timeout)
        except select.error:
            # sys.exc_info() works on every Python version; args[0] is the
            # errno for select.error (and for OSError, its Python 3 alias).
            exc = sys.exc_info()[1]
            if exc.args and exc.args[0] in (errno.EAGAIN, errno.EINTR):
                return [], []
            raise
        return read_out, write_out

    def reaper_add(self, sesid, life):
        """Schedule an expiry check for sesid in life seconds."""
        self.reaper.add(sesid, life)

    def new_session(self, sesid, life):
        """Create and return a new session."""
        return self.sessions.new_session(sesid, life)

    def del_session(self, sesid):
        """Delete a session if it exists."""
        return self.sessions.del_session(sesid)

    def get_session(self, sesid):
        """Return the session (bumping its last access) and log the read."""
        ses = self.sessions.get_session(sesid)
        if ses:
            self.log("session get %s, %d bytes, %d total\n" % \
                (sesid, len(ses.text), self.sessions.usage))
        return ses

    def set_session(self, sesid, text):
        """Store text for sesid and log the write."""
        result = self.sessions.set_session(sesid, text)
        # Log after the update so the reported total includes this payload,
        # consistent with the total reported by get_session.
        self.log("session put %s, %d bytes, %d total\n" % \
            (sesid, len(text), self.sessions.usage))
        return result

    def has_session(self, sesid):
        """Return true if sesid exists, without touching its access time."""
        return self.sessions.has_session(sesid)


# Client input-parser states (Client.state): COMMAND reads CRLF-terminated
# command lines; PUT collects a session payload terminated by a blank line.
COMMAND = 0
PUT = 1


class Client:
    """One connected client of the line-based session protocol.

    Incoming data is buffered in ``input`` and parsed according to
    ``state``; replies are queued in ``output`` and flushed by do_write()
    once select() reports the socket writable.
    """

    def __init__(self, context, sock, addr):
        self.context = context

        self.sock = sock
        self.addr = addr
        self.sock.setblocking(0)
        self.input = ''
        self.output = []
        self.state = COMMAND
        # Session key that the payload of an in-progress "put" is stored under.
        self.key = None
        # Net count of server-side exceptions caused by this client; see
        # handle_command().
        self.error_count = 0
        self.context.log('new client: %s from %s\n' % (sock.fileno(), addr[0]))

    def __del__(self):
        self.context.log('del client: %s\n' % self.sock.fileno())

    def fileno(self):
        """Return the socket descriptor so select() can watch a Client."""
        return self.sock.fileno()

    def _is_transient(self, exc):
        """Return true if socket error exc only means "try again later"."""
        return exc.args and exc.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK,
                                            errno.EINTR)

    def do_read(self):
        """Read available data and handle every complete request in it.

        Returns 1 to keep the connection, 0 to drop it (EOF, a hard socket
        error, "quit", or too many server-side failures).
        """
        try:
            data = self.sock.recv(16384)
        except socket.error:
            # A non-blocking socket can still report EAGAIN/EINTR after
            # select(); that is not a disconnect.
            if self._is_transient(sys.exc_info()[1]):
                return 1
            return 0
        if not data:
            # Peer closed the connection.
            return 0
        self.input = self.input + data
        while 1:
            if self.state == COMMAND:
                pos = self.input.find('\r\n')
                if pos < 0:
                    return 1
                command = self.input[:pos]
                self.input = self.input[pos + 2:]
                if not self.handle_command(command):
                    return 0
            elif self.state == PUT:
                # The payload runs up to the first blank line.
                pos = self.input.find('\r\n\r\n')
                if pos < 0:
                    return 1
                text = self.input[:pos]
                self.input = self.input[pos + 4:]
                self.context.set_session(self.key, text)
                self.write('OK\r\n')
                self.state = COMMAND

    def do_write(self):
        """Send as much queued output as the socket will take.

        Stops watching for writability once the queue is empty. Returns 1
        to keep the connection, 0 if sending failed.
        """
        while self.output:
            chunk = self.output[0]
            try:
                num = self.sock.send(chunk)
            except socket.error:
                if self._is_transient(sys.exc_info()[1]):
                    # Keep the queue; select() will report writable again.
                    return 1
                return 0
            if num == len(chunk):
                del self.output[0]
            else:
                # Partial send: keep the unsent tail for next time.
                self.output[0] = chunk[num:]
                return 1
        self.context.del_write_file(self)
        return 1

    def write(self, str):
        """Queue str for sending, registering for writability if needed."""
        if not self.output:
            self.context.add_write_file(self)
        self.output.append(str)

    def make_sesid(self, app):
        """Return a random hex session id not already in use for app.

        Uses 8 bytes from /dev/urandom, falling back to the packed bytes of
        random.random() (far less random) if the device cannot be read.
        """
        while 1:
            try:
                rand_file = open('/dev/urandom', 'rb')
                try:
                    text = rand_file.read(8)
                finally:
                    rand_file.close()
            except (IOError, OSError):
                text = struct.pack('d', random.random())
            sesid = binascii.hexlify(text)
            # has_session neither logs nor bumps an existing session's
            # last_access, unlike get_session.
            if not self.context.has_session((app, sesid)):
                return sesid

    def handle_command(self, str):
        """Execute one protocol command line.

        Commands (verb is case-insensitive):
          new app life   -> "OK sesid"; create a session living life seconds
          get app sesid  -> "OK - session follows", the text, a blank line
          put app sesid  -> switch to PUT state to receive the text
          del app sesid  -> "OK"; delete the session
          quit           -> close the connection

        Returns 1 to keep the connection, 0 to drop it.
        """
        self.context.log('client %s cmd: %s\n' % (self.sock.fileno(), str))
        try:
            words = str.split()
            if not words:
                return 1
            cmd = words[0].lower()
            if cmd == 'new' and len(words) == 3:
                # new app life\r\n -> sesid\r\n
                # generate a new session id
                app = words[1]
                try:
                    life = int(words[2])
                except ValueError:
                    self.write('ERROR bad life\r\n')
                    return 1
                sesid = self.make_sesid(app)
                self.write('OK %s\r\n' % sesid)
                key = (app, sesid)
                self.context.new_session(key, life)
                self.context.reaper_add(key, life)
            elif cmd == 'get' and len(words) == 3:
                # get app sesid\r\n -> text\r\n
                # get the data for a session
                app, sesid = words[1:]
                key = (app, sesid)
                ses = self.context.get_session(key)
                if not ses:
                    self.write('ERROR no such session\r\n')
                else:
                    self.write('OK - session follows\r\n')
                    self.write(ses.text)
                    self.write('\r\n\r\n')
            elif cmd == 'put' and len(words) == 3:
                # put app sesid\r\ntext\r\n -> OK\r\n
                # put session data
                app, sesid = words[1:]
                key = (app, sesid)
                if not self.context.has_session(key):
                    self.write('ERROR no such session\r\n')
                else:
                    self.key = key
                    self.state = PUT
                    self.write('OK - send data now, terminate with blank line\r\n')
            elif cmd == 'del' and len(words) == 3:
                # del app sesid\r\n -> OK\r\n
                # delete session
                app, sesid = words[1:]
                key = (app, sesid)
                ses = self.context.get_session(key)
                if not ses:
                    self.write('ERROR no such session\r\n')
                else:
                    self.context.del_session(key)
                    self.write('OK\r\n')
            elif cmd == 'quit' and len(words) == 1:
                return 0
            else:
                self.write('ERROR unrecognised command: %s\r\n' % str)
            # Successful commands gradually forgive earlier failures.
            if self.error_count > 0:
                self.error_count -= 1
            return 1
        except Exception:
            # Record the failure before deciding whether to disconnect, so
            # the one that triggers the disconnect is logged too.
            traceback.print_exc()
            self.error_count += 1
            if self.error_count > MAX_CLIENT_EXCEPTIONS:
                return 0
            self.write('ERROR session server failure - check logs\r\n')
            return 1


class Server:
    """TCP front end: listens for clients and drives the select() loop."""

    def __init__(self, port, log_file):
        self.context = Context(log_file)

        self.listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow the port to be rebound promptly after a restart.
        self.listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.listen_sock.setblocking(0)
        self.listen_sock.bind(('', port))
        self.listen_sock.listen(5)

        self.context.add_read_file(self.listen_sock)

    def select_loop(self):
        """Serve clients until stop() is called."""
        self.run = 1
        while self.run:
            read_out, write_out = self.context.select()
            for s in read_out:
                if s is self.listen_sock:
                    self._accept()
                elif not s.do_read():
                    self.context.del_file(s)
            for s in write_out:
                if not s.do_write():
                    self.context.del_file(s)
            # Drop loop references so finished Client objects (and their
            # sockets) can be freed before the next select() wait.
            read_out = write_out = s = None

    def _accept(self):
        """Accept one pending connection and register it for reading.

        A connection that vanished between select() and accept(), or a
        spurious wakeup, is skipped rather than aborting the server loop.
        Any other accept() error is re-raised.
        """
        try:
            c, addr = self.listen_sock.accept()
        except socket.error:
            exc = sys.exc_info()[1]
            if exc.args and exc.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK,
                                            errno.EINTR, errno.ECONNABORTED):
                return
            raise
        # The local name c is released when this method returns, so the
        # socket is owned solely by the Client.
        self.context.add_read_file(Client(self.context, c, addr))

    def stop(self):
        # This allows signal handlers to ask us to stop
        self.run = 0


def usage():
    """Print the command-line synopsis to standard error."""
    synopsis = 'usage: session-server.py [-h] [-p port] [-l log-file]\n'
    sys.stderr.write(synopsis)


if __name__ == '__main__':
    # Defaults: listen on port 34343, no logging.
    port = 34343
    log_file = None
    try:
        opts, args = getopt.getopt(sys.argv[1:],
                                   "hp:l:", ["help", "port=", "log="])
    except getopt.GetoptError:
        usage()
        sys.exit(2)
    for opt, arg in opts:
        if opt in ("-h", "--help"):
            usage()
            sys.exit()
        elif opt in ("-p", "--port"):
            try:
                port = int(arg)
            except ValueError:
                # Treat a malformed port like any other bad option.
                usage()
                sys.exit(2)
        elif opt in ("-l", "--log"):
            log_file = open(arg, 'a')
    server = Server(port, log_file)
    server.select_loop()
