import mock
import koji

from koji.xmlrpcplus import Fault


class BaseFakeClientSession(koji.ClientSession):
    """Common base for the fake client sessions used in tests.

    Overrides multiCall() so that queued calls are dispatched one at a time
    through self._callMethod(), which subclasses replace.
    """

    def __init__(self, *a, **kw):
        super(BaseFakeClientSession, self).__init__(*a, **kw)

    def multiCall(self, strict=False):
        """Run the queued multicall entries through _callMethod

        Each queued entry is expected to be a dict with 'methodName' and
        'params' keys; the params are unpacked with koji.decode_args.

        :param strict: if true, re-raise the first error encountered;
                       otherwise record it in the result list as a dict
                       with 'faultCode' and 'faultString' keys
        :returns: list with one entry per queued call, in call order.
                  Successful results are appended as returned by
                  _callMethod.
        :raises Exception: if the session is not in multicall mode
        """
        if not self.multicall:
            raise Exception("not in multicall")
        ret = []
        # leave multicall mode first so that _callMethod performs the calls
        # instead of queueing them again
        self.multicall = False
        calls = self._calls
        self._calls = []
        for call in calls:
            method = call['methodName']
            args, kwargs = koji.decode_args(call['params'])
            try:
                ret.append(self._callMethod(method, args, kwargs))
            except Fault as fault:
                if strict:
                    raise
                ret.append({'faultCode': fault.faultCode,
                            'faultString': fault.faultString})
            except koji.GenericError as err:
                # _callMethod implementations may raise faults that were
                # already converted with koji.convertFault; report those the
                # same way as raw faults
                if strict:
                    raise
                ret.append({'faultCode': err.faultCode,
                            'faultString': str(err)})
        return ret


class FakeClientSession(BaseFakeClientSession):
    """Client session that replays call data loaded with load_calls()"""

    def __init__(self, *a, **kw):
        super(FakeClientSession, self).__init__(*a, **kw)
        # munged [method, args, kwargs] -> list of call dicts, in load order
        self._calldata = {}
        # munged [method, args, kwargs] -> index of the next call to replay
        self._offsets = {}

    def load_calls(self, data):
        """Load call data

        Data should be a list of dictionaries with keys:
            - method
            - args
            - kwargs
            - result  (for successful calls)
            - fault   (for errors)
        That represent call data, e.g. as generated by RecordingClientSession
        """

        for call in data:
            key = self._munge([call['method'], call['args'], call['kwargs']])
            self._calldata.setdefault(key, []).append(call)

    def _callMethod(self, name, args, kwargs=None, retry=True):
        """Replay the loaded result (or fault) for the given call

        While in multicall mode the call is passed to the parent
        implementation. Otherwise the loaded entries matching
        (name, args, kwargs) are replayed in load order; once the last
        entry is reached it is returned for every further call.

        :returns: the 'result' value of the matching entry, or a
                  mock.MagicMock if nothing usable was loaded for this call
        :raises: the loaded fault, passed through koji.convertFault
        """
        if self.multicall:
            return super(FakeClientSession, self)._callMethod(name, args,
                            kwargs, retry)
        key = self._munge([name, args, kwargs])
        # we may have a series of calls for each key
        calls = self._calldata.get(key)
        if not calls:
            # nothing was loaded for this call; fall back to a mock instead
            # of failing on the lookup below
            return mock.MagicMock()
        ofs = self._offsets.get(key, 0)
        call = calls[ofs]
        if ofs + 1 < len(calls):
            # don't go past the end
            self._offsets[key] = ofs + 1
        if not call:
            return mock.MagicMock()
        if 'fault' in call:
            fault = Fault(call['fault']['faultCode'],
                          call['fault']['faultString'])
            raise koji.convertFault(fault)
        return call['result']

    def _munge(self, data):
        """Convert call data into a form usable as a dictionary key

        Lists become tuples and dicts become tuples of (key, value) pairs
        sorted by key. The conversion is applied through
        koji.util.DataWalker (presumably to nested values as well --
        confirm against DataWalker).
        """
        def callback(value):
            if isinstance(value, list):
                return tuple(value)
            elif isinstance(value, dict):
                keys = sorted(value.keys())
                return tuple([(k, value[k]) for k in keys])
            else:
                return value
        walker = koji.util.DataWalker(data, callback)
        return walker.walk()


class RecordingClientSession(BaseFakeClientSession):
    """Client session that keeps a record of every call it performs"""

    def __init__(self, *a, **kw):
        super(RecordingClientSession, self).__init__(*a, **kw)
        # one dict per performed call, in call order
        self._calldata = []

    def get_calls(self):
        """Return the list of recorded call dicts"""
        return self._calldata

    def _callMethod(self, name, args, kwargs=None, retry=True):
        """Perform the call via the parent class and record its outcome

        While in multicall mode the call is passed straight to the parent
        implementation and nothing is recorded. Otherwise a dict with
        'method', 'args' and 'kwargs' keys is appended to the call data
        before the call is made, and then completed with either a 'result'
        key or a 'fault' key ('faultCode' / 'faultString'). Errors are
        re-raised after being recorded.
        """
        do_call = super(RecordingClientSession, self)._callMethod
        if self.multicall:
            return do_call(name, args, kwargs, retry)
        record = {
            'method': name,
            'args': args,
            'kwargs': kwargs,
        }
        # stored up front so the entry exists even if the call fails
        self._calldata.append(record)
        try:
            outcome = do_call(name, args, kwargs, retry)
        except Fault as fault:
            record['fault'] = {'faultCode': fault.faultCode,
                               'faultString': fault.faultString}
            raise
        except koji.GenericError as exc:
            record['fault'] = {'faultCode': exc.faultCode,
                               'faultString': str(exc)}
            raise
        record['result'] = outcome
        return outcome
