from functools import partial

import matplotlib.animation as mplanim
import matplotlib.axes as maxes
import matplotlib.backend_bases as mback
import matplotlib.figure as mfigure
import matplotlib.pyplot as plt
import numpy as np
import pytest
from mpl_animators import base

from sunpy.tests.helpers import figure_test
from sunpy.visualization.animator import ArrayAnimator, BaseFuncAnimator, LineAnimator

pytestmark = pytest.mark.filterwarnings("ignore::sunpy.util.exceptions.SunpyDeprecationWarning")


class FuncAnimatorTest(BaseFuncAnimator):
    """Minimal concrete ``BaseFuncAnimator`` used by the tests in this module."""

    def plot_start_image(self, ax):
        """
        Draw the first frame of ``self.data`` on ``ax`` and return the image.

        A colorbar is attached to the image when ``self.if_colorbar`` is truthy.
        """
        image = ax.imshow(self.data[0])
        if not self.if_colorbar:
            return image
        self._add_colorbar(image)
        return image


def update_plotval(val, im, slider, data):
    """
    Slider callback: show frame ``int(val)`` of ``data`` on the image ``im``.

    ``slider`` is accepted to match the callback signature but is not used.
    """
    frame = data[int(val)]
    im.set_array(frame)


def button_func1(*args, **kwargs):
    """Button callback for the tests: forward all arguments to ``print``."""
    print(*args, **kwargs)


@pytest.mark.parametrize('fig, colorbar, buttons',
                         ((None, False, [[], []]),
                          (mfigure.Figure(), True, [[button_func1], ["hi"]])))
def test_base_func_init(fig, colorbar, buttons):
    """
    Drive a two-slider ``FuncAnimatorTest`` through its interaction handlers.

    Parameters
    ----------
    fig
        Figure handed to the animator, or `None`.
    colorbar
        Forwarded as the ``colorbar`` argument of the animator.
    buttons
        Two-element list: the button callbacks and their labels.

    The test checks slider labelling, active-slider selection by key and by
    mouse, the play/pause button label, stepping with the arrow keys,
    starting and stopping the play timer, and that changing a slider updates
    the displayed image.
    """
    data = np.random.random((3, 10, 10))
    # One frame stack per slider, in slider order; kept so the final image
    # check can look up what the changed slider should display.
    frames = [data, data*10]
    funcs = [partial(update_plotval, data=stack) for stack in frames]
    ranges = [(0, 3), (0, 3)]

    tfa = FuncAnimatorTest(data, funcs, ranges, fig=fig, colorbar=colorbar,
                           button_func=buttons[0],
                           button_labels=buttons[1])

    tfa.label_slider(0, "hello")
    assert tfa.sliders[0]._slider.label.get_text() == "hello"

    tfa._set_active_slider(1)
    assert tfa.active_slider == 1

    # 'down' / 'up' switch the active slider.
    fig = tfa.fig
    event = mback.KeyEvent(name='key_press_event', canvas=fig.canvas, key='down')
    tfa._key_press(event)
    assert tfa.active_slider == 0

    event.key = 'up'
    tfa._key_press(event)
    assert tfa.active_slider == 1

    # Clicking the slider button shows the pause label; 'p' toggles it back.
    tfa.slider_buttons[tfa.active_slider]._button.clicked = False
    event.key = 'p'
    tfa._click_slider_button(event=event, button=tfa.slider_buttons[tfa.active_slider]._button,
                             slider=tfa.sliders[tfa.active_slider]._slider)
    assert tfa.slider_buttons[tfa.active_slider]._button.label.get_text() == "||"

    tfa._key_press(event)
    assert tfa.slider_buttons[tfa.active_slider]._button.label.get_text() == ">"

    # 'left' / 'right' step the active slider and wrap around at the ends.
    event.key = 'left'
    tfa._key_press(event)
    assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmax

    event.key = 'right'
    tfa._key_press(event)
    assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin

    event.key = 'right'
    tfa._key_press(event)
    assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin + 1

    event.key = 'left'
    tfa._key_press(event)
    assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin

    tfa._start_play(event, tfa.slider_buttons[tfa.active_slider]._button,
                    tfa.sliders[tfa.active_slider]._slider)
    assert tfa.timer

    tfa._stop_play(event)
    assert tfa.timer is None

    tfa._slider_changed(val=2, slider=tfa.sliders[tfa.active_slider]._slider)
    # Compare the displayed frame element-wise; comparing ``.all()`` results
    # would only compare two booleans and pass for any non-zero image.
    # NOTE(review): assumes slider callbacks are dispatched by slider index,
    # in the same order as ``funcs`` -- confirm against the animator base class.
    expected = frames[tfa.active_slider][2]
    assert np.array_equal(np.asarray(tfa.im.get_array()), expected)

    event.inaxes = tfa.sliders[0]
    tfa._mouse_click(event)
    assert tfa.active_slider == 0


# Make sure figures created directly and through pyplot work
@pytest.fixture(params=[plt.figure, mfigure.Figure])
def funcanimator(request):
    """
    Yield a single-slider ``FuncAnimatorTest`` built on a fresh figure.

    Parametrised over the figure factory (``plt.figure`` and
    ``mfigure.Figure``). The figure is closed through pyplot once the test
    has finished so pyplot-managed figures do not accumulate between tests.
    """
    data = np.random.random((3, 10, 10))
    func = partial(update_plotval, data=data)
    funcs = [func]
    ranges = [(0, 3)]
    fig = request.param()

    try:
        yield FuncAnimatorTest(data, funcs, ranges, fig=fig)
    finally:
        # Also runs when the animator itself fails to construct.
        plt.close(fig)


def test_to_anim(funcanimator):
    """``get_animation`` returns a matplotlib ``FuncAnimation``."""
    animation = funcanimator.get_animation()
    assert isinstance(animation, mplanim.FuncAnimation)


def test_to_axes(funcanimator):
    """The animator's ``axes`` attribute is a matplotlib subplot axes."""
    animator_axes = funcanimator.axes
    assert isinstance(animator_axes, maxes.SubplotBase)


def test_axes_set():
    """
    Creating an animator makes its axes the current pyplot axes.

    A second, unrelated figure and axes are made current first, so the final
    check only passes if constructing the animator changes the current axes.
    """
    data = np.random.random((3, 10, 10))
    funcs = [partial(update_plotval, data=data)]
    ranges = [(0, 3)]

    # Create Figure for animator
    fig1 = plt.figure()
    # Create new Figure, Axes, and set current axes
    fig2, ax = plt.subplots()
    try:
        plt.sca(ax)
        ani = FuncAnimatorTest(data, funcs, ranges, fig=fig1)
        # Make sure the animator axes is now the current axes
        assert plt.gca() is ani.axes
    finally:
        # Close both figures even when the assertion above fails.
        for opened in (fig1, fig2):
            plt.close(opened)


def test_edges_to_centers_nd():
    """
    ``edges_to_centers_nd`` turns 10 edges along axis 0 into 9 centres.

    Column 0 holds the edges 10..19, so the expected centres are 10.5..18.5;
    column 1 is all zeros in both the input and the expected output.
    """
    edges_axis = 0

    edges = np.zeros((10, 2))
    edges[:, 0] = np.arange(10, 20)

    centers = base.edges_to_centers_nd(edges, edges_axis)

    expected_centers = np.zeros((9, 2))
    expected_centers[:, edges_axis] = np.arange(10.5, 19)
    assert np.array_equal(centers, expected_centers)


class ArrayAnimatorTest(ArrayAnimator):
    """
    Stripped-down ``ArrayAnimator`` for exercising helper methods.

    ``__init__`` does not call the parent initialiser; it only sets the
    three attributes below.
    """

    def __init__(self, data):
        # Axis 0 is treated as the slider axis and axis 1 as the image axis.
        self.slider_axes = [0]
        self.image_axes = [1]
        self.naxis = data.ndim

    def plot_start_image(self, ax):
        """Do nothing; no image is drawn by this test double."""
        pass

    def update_plot(self, val, artist, slider):
        """Delegate to the parent implementation unchanged."""
        super().update_plot(val, artist, slider)


# 21 evenly spaced values from 0 to 100, repeated as 10 identical rows: shape (10, 21).
axis_ranges1 = np.tile(np.linspace(0, 100, num=21), reps=(10, 1))


@pytest.mark.parametrize('axis_ranges, exp_extent, exp_axis_ranges',
                         [([None, None], [-0.5, 19.5],
                           [np.arange(10), np.array([-0.5, 19.5])]),

                          ([[0, 10], [0, 20]], [0, 20],
                           [np.arange(0.5, 10.5), np.asarray([0, 20])]),

                          ([np.arange(0, 11), np.arange(0, 21)], [0, 20],
                           [np.arange(0.5, 10.5), np.arange(0.5, 20.5)]),

                          ([None, axis_ranges1], [0.0, 100.0],
                           [np.arange(10), base.edges_to_centers_nd(axis_ranges1, 1)])])
def test_sanitize_axis_ranges(axis_ranges, exp_extent, exp_axis_ranges):
    """
    Check ``_sanitize_axis_ranges`` on a (10, 20) array for several inputs.

    The extent and the range of axis 1 are compared directly. The range of
    axis 0 must be a callable, which is evaluated at ``np.arange(10)``
    before being compared with the expected values.
    """
    shape = (10, 20)
    animator = ArrayAnimatorTest(data=np.random.rand(*shape))

    sanitized, extent = animator._sanitize_axis_ranges(axis_ranges=axis_ranges,
                                                       data_shape=shape)

    assert exp_extent == extent
    assert np.array_equal(exp_axis_ranges[1], sanitized[1])

    slider_range = sanitized[0]
    assert callable(slider_range)
    assert np.array_equal(exp_axis_ranges[0], slider_range(np.arange(10)))


# 11 evenly spaced values from 0 to 100 along the last axis, tiled to shape (5, 5, 11).
XDATA = np.tile(np.linspace(0, 100, num=11), reps=(5, 5, 1))


@pytest.mark.parametrize('plot_axis_index, axis_ranges, xlabel, xlim',
                         [(-1, None, None, None),
                          (-1, [None, None, XDATA], 'x-axis', None)])
def test_lineanimator_init(plot_axis_index, axis_ranges, xlabel, xlim):
    """Constructing a ``LineAnimator`` from (5, 5, 10) data does not raise."""
    settings = {
        'plot_axis_index': plot_axis_index,
        'axis_ranges': axis_ranges,
        'xlabel': xlabel,
        'xlim': xlim,
    }
    LineAnimator(data=np.random.random((5, 5, 10)), **settings)


def test_lineanimator_init_nans():
    """
    Axis limits are still populated when one line of the data is all NaN.

    ``xlim`` and ``ylim`` are passed as `None`; afterwards both ends of
    ``ylim`` and ``xlim`` on the animator must not be `None`.
    """
    data = np.random.random((5, 5, 10))
    data[0, 0, :] = np.nan

    line_anim = LineAnimator(data=data, plot_axis_index=-1,
                             axis_ranges=[None, None, XDATA],
                             xlabel='x-axis', xlim=None, ylim=None)

    for attribute in ('ylim', 'xlim'):
        limits = getattr(line_anim, attribute)
        assert limits[0] is not None
        assert limits[1] is not None


@figure_test
def test_lineanimator_figure():
    """
    Figure test for a ``LineAnimator`` over seeded random (10, 20) data.

    Axis 0 is the slider axis and axis 1 the plotted axis. The x values for
    each row are 21 evenly spaced values from 0 to 100, one more than the
    number of plotted points.
    """
    np.random.seed(1)
    n_frames, n_points = 10, 20
    data0 = np.random.rand(n_frames, n_points)

    x_row = np.linspace(0, 100, n_points + 1)
    xdata = np.tile(x_row, (n_frames, 1))

    ani = LineAnimator(data0, plot_axis_index=1, axis_ranges=[None, xdata])
    return ani.fig
