#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright (C) 2018 Takashi Sakamoto

import time
import signal
import sys

from hinawa_utils.misc.cui_kit import CuiKit

from hinawa_utils.bebob.maudio_unit import MaudioUnit

def _handle_target_volume(unit, args, cmd, targets_func, set_func, get_func):
    """Run a ``TARGET CH OP [dB]`` command, or print its usage text.

    Args:
        unit: The unit the command was issued for; not used directly here.
        args: Command-line arguments following the command name.
        cmd: Command name, used only in the usage text.
        targets_func: Callable returning the valid TARGET labels.
        set_func: Callable ``(target, ch, db)`` applying a new value.
        get_func: Callable ``(target, ch)`` returning the current value,
            which is printed.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    chs = (0, 1)
    ops = ('set', 'get')
    targets = targets_func()
    if len(args) >= 3 and args[0] in targets and args[2] in ops:
        target, op = args[0], args[2]
        # Malformed numbers fall through to the usage text instead of
        # aborting the command with a traceback.
        try:
            ch = int(args[1])
        except ValueError:
            ch = None
        if ch in chs:
            if op == 'get':
                print(get_func(target, ch))
                return True
            if len(args) >= 4:
                try:
                    db = float(args[3])
                except ValueError:
                    db = None
                if db is not None:
                    set_func(target, ch, db)
                    return True
    print('Arguments for {0} command:'.format(cmd))
    print('  {0} TARGET CH OP [dB]'.format(cmd))
    print('    TARGET:   [{0}]'.format('|'.join(targets)))
    # Channel numbers are ints and must be converted before joining.
    print('    CH:       [{0}]'.format('|'.join(map(str, chs))))
    print('    OP:       [{0}]'.format('|'.join(ops)))
    print('    dB:       [-128.00..128.00] if OP=set')
    return False

def handle_input_gain(unit, args):
    """Handle the 'input-gain' command via the shared volume handler.

    The protocol's input label, gain setter and gain getter are passed
    through as the target, set and get callables. Returns whatever the
    shared handler returns.
    """
    protocol = unit.protocol
    return _handle_target_volume(
        unit, args, 'input-gain',
        protocol.get_input_labels,
        protocol.set_input_gain,
        protocol.get_input_gain,
    )

def handle_aux_input(unit, args):
    """Handle the 'aux-input' command via the shared volume handler.

    The protocol's aux input label, setter and getter are passed through
    as the target, set and get callables. Returns whatever the shared
    handler returns.
    """
    protocol = unit.protocol
    return _handle_target_volume(
        unit, args, 'aux-input',
        protocol.get_aux_input_labels,
        protocol.set_aux_input,
        protocol.get_aux_input,
    )

def handle_output_volume(unit, args):
    """Handle the 'output-volume' command via the shared volume handler.

    The protocol's output label, volume setter and volume getter are
    passed through as the target, set and get callables. Returns whatever
    the shared handler returns.
    """
    protocol = unit.protocol
    return _handle_target_volume(
        unit, args, 'output-volume',
        protocol.get_output_labels,
        protocol.set_output_volume,
        protocol.get_output_volume,
    )

def handle_headphone_volume(unit, args):
    """Handle the 'headphone-volume' command via the shared volume handler.

    The protocol's headphone label, volume setter and volume getter are
    passed through as the target, set and get callables. Returns whatever
    the shared handler returns.
    """
    protocol = unit.protocol
    return _handle_target_volume(
        unit, args, 'headphone-volume',
        protocol.get_headphone_labels,
        protocol.set_headphone_volume,
        protocol.get_headphone_volume,
    )

def handle_input_balance(unit, args):
    """Handle the 'input-balance TARGET CH OP [BALANCE]' command.

    Args:
        unit: Object whose ``protocol`` provides
            ``get_input_balance_labels``, ``set_input_balance`` and
            ``get_input_balance``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    chs = (0, 1)
    ops = ('set', 'get')
    targets = unit.protocol.get_input_balance_labels()
    if len(args) >= 3 and args[0] in targets and args[2] in ops:
        target, op = args[0], args[2]
        # Malformed numbers fall through to the usage text instead of
        # aborting the command with a traceback.
        try:
            ch = int(args[1])
        except ValueError:
            ch = None
        if ch in chs:
            if op == 'get':
                print(unit.protocol.get_input_balance(target, ch))
                return True
            if len(args) >= 4:
                try:
                    balance = float(args[3])
                except ValueError:
                    balance = None
                if balance is not None:
                    unit.protocol.set_input_balance(target, ch, balance)
                    return True
    print('Arguments for input-balance command')
    print('  input-balance TARGET CH OP [BALANCE]')
    print('    TARGET: [{0}]'.format('|'.join(targets)))
    # Channel numbers are ints and must be converted before joining.
    print('    CH:     [{0}]'.format('|'.join(map(str, chs))))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    BALANCE:[-128.0..128.0] (left-right) if OP=set')
    return False

def handle_aux_balance(unit, args):
    """Handle the 'aux-balance CH OP [BALANCE]' command.

    Args:
        unit: Object whose ``protocol`` provides ``set_aux_balance`` and
            ``get_aux_balance``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    chs = (0, 1)
    ops = ('set', 'get')
    if len(args) >= 2 and args[1] in ops:
        op = args[1]
        # Malformed numbers fall through to the usage text instead of
        # aborting the command with a traceback.
        try:
            ch = int(args[0])
        except ValueError:
            ch = None
        if ch in chs:
            if op == 'get':
                print(unit.protocol.get_aux_balance(ch))
                return True
            # 'set' without a value is an error; it must not silently
            # behave like 'get'.
            if len(args) >= 3:
                try:
                    balance = float(args[2])
                except ValueError:
                    balance = None
                if balance is not None:
                    unit.protocol.set_aux_balance(ch, balance)
                    return True
    print('Arguments for aux-balance command:')
    print('  aux-balance CH OP [BALANCE]')
    # Channel numbers are ints and must be converted before joining.
    print('    CH:     [{0}]'.format('|'.join(map(str, chs))))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    BALANCE:[-128.0..128.0] (left-right) if OP=set')
    return False

def handle_aux_volume(unit, args):
    """Handle the 'aux-volume CH OP [dB]' command.

    Args:
        unit: Object whose ``protocol`` provides ``set_aux_volume`` and
            ``get_aux_volume``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    chs = (0, 1)
    ops = ('set', 'get')
    if len(args) >= 2 and args[1] in ops:
        op = args[1]
        # Malformed numbers fall through to the usage text instead of
        # aborting the command with a traceback.
        try:
            ch = int(args[0])
        except ValueError:
            ch = None
        if ch in chs:
            if op == 'get':
                print(unit.protocol.get_aux_volume(ch))
                return True
            if len(args) >= 3:
                try:
                    db = float(args[2])
                except ValueError:
                    db = None
                if db is not None:
                    unit.protocol.set_aux_volume(ch, db)
                    return True
    print('Arguments for aux-volume command:')
    print('  aux-volume CH OP [dB]')
    # Channel numbers are ints and must be converted before joining.
    print('    CH:     [{0}]'.format('|'.join(map(str, chs))))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    dB:     [-128.0..128.0] if OP=set')
    return False

def handle_mixer_routing(unit, args):
    """Handle the 'mixer-routing TARGET SOURCE OP [ENABLE]' command.

    Args:
        unit: Object whose ``protocol`` provides ``get_mixer_labels``,
            ``get_mixer_source_labels``, ``set_mixer_routing`` and
            ``get_mixer_routing``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    ops = ('set', 'get')
    targets = unit.protocol.get_mixer_labels()
    sources = unit.protocol.get_mixer_source_labels()
    if (len(args) >= 3 and args[0] in targets and args[1] in sources
            and args[2] in ops):
        target, source, op = args[0], args[1], args[2]
        if op == 'get':
            print(unit.protocol.get_mixer_routing(target, source))
            return True
        if len(args) >= 4:
            # A non-numeric ENABLE falls through to the usage text instead
            # of aborting the command with a traceback.
            try:
                enable = bool(int(args[3]))
            except ValueError:
                enable = None
            if enable is not None:
                unit.protocol.set_mixer_routing(target, source, enable)
                return True
    print('Arguments for mixer-routing command:')
    # OP is mandatory, so the synopsis has to list it.
    print('  mixer-routing TARGET SOURCE OP [ENABLE]')
    print('    TARGET: [{0}]'.format('|'.join(targets)))
    print('    SOURCE: [{0}]'.format('|'.join(sources)))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    ENABLE: [0|1]')
    return False

def _handle_target_source(unit, args, cmd, targets_func, sources_func,
                          set_func, get_func):
    """Run a ``TARGET OP [SRC]`` command, or print its usage text.

    Args:
        unit: The unit the command was issued for; not used directly here.
        args: Command-line arguments following the command name.
        cmd: Command name, used only in the usage text.
        targets_func: Callable returning the valid TARGET labels.
        sources_func: Callable ``(target)`` returning the SRC labels valid
            for that target.
        set_func: Callable ``(target, source)`` selecting a source.
        get_func: Callable ``(target)`` returning the current source,
            which is printed.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    ops = ('set', 'get')
    targets = targets_func()
    argc = len(args)

    if argc >= 2 and args[0] in targets and args[1] in ops:
        dst, action = args[0], args[1]
        # The source labels are looked up for both operations.
        candidates = sources_func(dst)
        if action == 'set' and argc >= 3 and args[2] in candidates:
            set_func(dst, args[2])
            return True
        if action == 'get':
            print(get_func(dst))
            return True

    # Anything else is a usage error.
    header = (
        'Arguments for {0} command:'.format(cmd),
        '  {0} TARGET OP [SRC]'.format(cmd),
        '    TARGET:    [{0}]'.format('|'.join(targets)),
        '    OP:        [{0}]'.format('|'.join(ops)),
    )
    for line in header:
        print(line)
    for dst in targets:
        joined = '|'.join(sources_func(dst))
        print('    SRC:       [{0}] if TARGET={1} and OP=set'.format(
            joined, dst))
    return False

def handle_headphone_source(unit, args):
    """Handle the 'headphone-source' command via the shared source handler.

    The protocol's headphone label, source-label, source setter and source
    getter are passed through as the callables. Returns whatever the
    shared handler returns.
    """
    protocol = unit.protocol
    return _handle_target_source(
        unit, args, 'headphone-source',
        protocol.get_headphone_labels,
        protocol.get_headphone_source_labels,
        protocol.set_headphone_source,
        protocol.get_headphone_source,
    )

def handle_output_source(unit, args):
    """Handle the 'output-source' command via the shared source handler.

    The protocol's output label, source-label, source setter and source
    getter are passed through as the callables. Returns whatever the
    shared handler returns.
    """
    protocol = unit.protocol
    return _handle_target_source(
        unit, args, 'output-source',
        protocol.get_output_labels,
        protocol.get_output_source_labels,
        protocol.set_output_source,
        protocol.get_output_source,
    )

def handle_clock_source(unit, args):
    """Handle the 'clock-source OP [SRC]' command.

    Args:
        unit: Object whose ``protocol`` provides
            ``get_clock_source_labels``, ``set_clock_source`` and
            ``get_clock_source``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    ops = ('set', 'get')
    sources = unit.protocol.get_clock_source_labels()
    action = args[0] if len(args) > 0 else None

    if action == 'get':
        print(unit.protocol.get_clock_source())
        return True
    if action == 'set' and len(args) > 1 and args[1] in sources:
        unit.protocol.set_clock_source(args[1])
        return True

    # Unknown operation, or 'set' without a valid source.
    usage = (
        'Arguments for clock-source command:',
        '  clock-source OP [SRC]',
        '    OP:     {0}'.format('|'.join(ops)),
        '    SRC:    {0} if OP=set'.format('|'.join(sources)),
        '  Packet streaming should be stopped.',
    )
    for line in usage:
        print(line)
    return False

def handle_listen_metering(unit, args):
    """Print the unit's meter values repeatedly until SIGINT arrives.

    Installs a SIGINT handler that exits the process, then loops: reads
    ``unit.protocol.get_meters()``, prints each entry sorted by name as
    eight hexadecimal digits, prints a blank line and sleeps for 0.1
    seconds.

    Args:
        unit: Object whose ``protocol`` provides ``get_meters``.
        args: Unused.
    """
    def _exit_on_signal(signum, frame):
        # Invoked by the signal machinery, not by the polling loop below.
        sys.exit()

    signal.signal(signal.SIGINT, _exit_on_signal)

    while True:
        # At higher sampling rate, reading the meters times out
        # frequently, so a failed read is skipped and retried on the
        # next iteration.
        try:
            levels = unit.protocol.get_meters()
            for label in sorted(levels):
                print('{0}: {1:08x}'.format(label, levels[label]))
        except Exception:
            pass
        print('')
        time.sleep(0.1)
    return True

def handle_sampling_rate(unit, args):
    """Handle the 'sampling-rate OP [RATE]' command.

    Args:
        unit: Object whose ``protocol`` provides
            ``get_sampling_rate_labels``, ``set_sampling_rate`` and
            ``get_sampling_rate``.
        args: Command-line arguments following the command name.

    Returns:
        True when a get or set operation was performed, False when the
        arguments were invalid and the usage text was printed instead.
    """
    ops = ('set', 'get')
    rates = unit.protocol.get_sampling_rate_labels()
    if len(args) >= 1 and args[0] in ops:
        op = args[0]
        if op == 'get':
            print(unit.protocol.get_sampling_rate())
            return True
        if len(args) >= 2:
            # A non-numeric RATE falls through to the usage text instead
            # of aborting the command with a traceback. Whether the rate
            # is supported is left to the protocol to decide.
            try:
                rate = int(args[1])
            except ValueError:
                rate = None
            if rate is not None:
                unit.protocol.set_sampling_rate(rate)
                return True
    print('Arguments for sampling-rate command:')
    print('  sampling-rate OP [RATE]')
    print('    OP:     {0}'.format('|'.join(ops)))
    print('    RATE:   {0}'.format('|'.join(map(str, rates))))
    return False

# Command name -> handler for the commands registered unconditionally.
# Further entries are added below depending on what the unit reports.
cmds = dict((
    ('input-gain', handle_input_gain),
    ('input-balance', handle_input_balance),
    ('mixer-routing', handle_mixer_routing),
    ('listen-metering', handle_listen_metering),
    ('clock-source', handle_clock_source),
    ('sampling-rate', handle_sampling_rate),
))

# Locate the sound unit, register the commands its protocol supports and
# hand control over to the command dispatcher.
fullpath = CuiKit.seek_snd_unit_path()
if fullpath:
    unit = MaudioUnit(fullpath)

    # Aux commands are registered only when aux input labels are reported.
    if len(unit.protocol.get_aux_input_labels()):
        cmds.update({
            'aux-input': handle_aux_input,
            'aux-volume': handle_aux_volume,
        })
        # Aux balance needs both accessors on the protocol object.
        if all(hasattr(unit.protocol, attr)
               for attr in ('set_aux_balance', 'get_aux_balance')):
            cmds['aux-balance'] = handle_aux_balance

    # Output commands are registered only when output labels are reported.
    if len(unit.protocol.get_output_labels()):
        cmds.update({
            'output-volume': handle_output_volume,
            'output-source': handle_output_source,
        })

    # Headphone commands are registered only when headphone labels are
    # reported.
    if len(unit.protocol.get_headphone_labels()):
        cmds.update({
            'headphone-volume': handle_headphone_volume,
            'headphone-source': handle_headphone_source,
        })

    CuiKit.dispatch_command(unit, cmds)
