/* Ergo, version 3.8, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2019 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
 * and Anastasia Kruchinina.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Ergo: An open-source program for linear-scaling electronic structure
 * calculations,
 * Elias Rudberg, Emanuel H. Rubensson, Pawel Salek, and Anastasia
 * Kruchinina,
 * SoftwareX 7, 107 (2018),
 * <http://dx.doi.org/10.1016/j.softx.2018.03.005>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */

/** @file recexp_many_tests.cc

    @brief  Test serial recursive expansion on a random symmetric matrix or
            matrix from a given binary file. Matrix in a binary file should contain only the upper triangle. Note: to get homo-lumo gap all matrix eigenvalues are computed.

    @author Anastasia Kruchinina <em>responsible</em>
 */

#ifndef USE_CHUNKS_AND_TASKS

#include "purification_sp2.h"
#include "purification_sp2acc.h"
#include "matrix_typedefs.h" // definitions of matrix types and interval type (source)
#include "realtype.h"   // definitions of types (utilities_basic)
#include "matrix_utilities.h"
#include "integral_matrix_wrappers.h"
#include "SizesAndBlocks.h"
#include "Matrix.h"
#include "Vector.h"
#include "MatrixSymmetric.h"
#include "MatrixTriangular.h"
#include "MatrixGeneral.h"
#include "VectorGeneral.h"
#include "output.h"

#include <iostream>
#include <fstream>
#include <sstream>
#include <string.h>

#include "random_matrices.h"
#include "get_eigenvectors.h"

// Short aliases for the floating-point type and the matrix/vector types used by the tests in this file.
typedef ergo_real real;
typedef symmMatrix MatrixType;
typedef MatrixType::VectorType VectorType;

// Square root of the machine epsilon of the type `real`.
#define SQRT_EPSILON_REAL    template_blas_sqrt(mat::getMachineEpsilon<real>())

// Default tolerated subspace error; the set_test_*_data functions use it for DATA.err_sub.
real TOL_ERR_SUBS_DEFAULT = SQRT_EPSILON_REAL;
// Largest accepted value of |trace(X) - N_occ| in the checks done in main.
real TOL_TRACE_ERROR_DEFAULT = SQRT_EPSILON_REAL;

// Precision-dependent eigensolver accuracy. It is defined only if one of the
// PRECISION_* macros below is set, and it is not referenced in the code visible in this file.
#ifdef PRECISION_SINGLE
real TOL_EIGENSOLVER_ACC_DEFAULT = 1e-6;
#elif PRECISION_DOUBLE
real TOL_EIGENSOLVER_ACC_DEFAULT = 1e-12;
#elif PRECISION_LONG_DOUBLE
real TOL_EIGENSOLVER_ACC_DEFAULT = 1e-16;
#elif PRECISION_QUAD_FLT128
real TOL_EIGENSOLVER_ACC_DEFAULT = 1e-24;
#endif


/*
 * Ask the given SP2 purification object to generate its Matlab (.m) result
 * files (norm differences, thresholds, number of non-zeros, eigenvalues and
 * timings) and report on stdout that this was done.
 * The call to this function in main is currently commented out.
 */
void plot_results(const Purification_sp2<MatrixType> & Puri)
{
   // names of the generated files, one per kind of collected data
   const char * const norm_diff_file = "puri_out_error.m";
   const char * const threshold_file = "puri_out_threshold.m";
   const char * const nnz_file       = "puri_out_nnz.m";
   const char * const eigs_file      = "puri_out_eigs.m";
   const char * const time_file      = "puri_out_time.m";

   Puri.gen_matlab_file_norm_diff(norm_diff_file);
   Puri.gen_matlab_file_threshold(threshold_file);
   Puri.gen_matlab_file_nnz(nnz_file);
   Puri.gen_matlab_file_eigs(eigs_file);
   Puri.gen_matlab_file_time(time_file);

   std::cout << "Created .m files with results of the purification" << std::endl;
}

/*
 * This function is used to set structure of the matrix
 * (see function resetSizesAndBlocks is the matrix library).
 */
template<typename Matrix>
void init_matrix(Matrix& X, const int N)
{
        /********** Initialization of SizesAndBlocks */
        int size    = N;
        int nlevels = 5; //!!!

        std::vector<int> blockSizes(nlevels);
        blockSizes[nlevels - 1] = 1; // should always be one
        for (int ind = nlevels - 2; ind >= 0; ind--)
        {
                blockSizes[ind] = blockSizes[ind + 1] * 2;
        }
        mat::SizesAndBlocks rows(blockSizes, size);
        mat::SizesAndBlocks cols(blockSizes, size);
        /********************************************/
        X.resetSizesAndBlocks(rows, cols);
}


typedef struct data_for_recexp
{
        int N;
        int N_occ;
        real err_sub;
        real err_eig;
        mat::normType normPuri;
        mat::normType normPuriStopCrit;
        real homo_in;
        real homo_out;
        real lumo_in;
        real lumo_out;
        int maxit;

        bool homo_lumo_bounds_known;
        bool compute_eigenvectors;
        int which_stop_crit;

} data_for_recexp_t;


// Number of test cases; main runs the tests numbered 1..TEST_COUNT.
#define TEST_COUNT 8

// Fills DATA for the test with the given number (by calling the matching
// set_test_*_data function below) and prints the chosen parameters.
void prepare_data_for_recexp(int testnum, data_for_recexp_t &DATA);
// One function per test case; each one assigns the fields of DATA.
void set_test_1_data(data_for_recexp_t &DATA);
void set_test_2_data(data_for_recexp_t &DATA);
void set_test_3_data(data_for_recexp_t &DATA);
void set_test_4_data(data_for_recexp_t &DATA);
void set_test_5_data(data_for_recexp_t &DATA);
void set_test_6_data(data_for_recexp_t &DATA);
void set_test_7_data(data_for_recexp_t &DATA);
void set_test_8_data(data_for_recexp_t &DATA);
// Test 9 is disabled, see the commented-out definition at the end of the file.
//void set_test_9_data(data_for_recexp_t &DATA);


int main(int argc, char *argv[])
{
        printf("Program performing the recursive expansion on a given matrix.\n");
        printf("Written by Anastasia Kruchinina, Feb 2019\n");
        printf("\n");

        #ifdef _OPENMP
        int defThreads;
        const char *env = getenv("OMP_NUM_THREADS");
        if ( !(env && (defThreads=atoi(env)) > 0) ) {
                defThreads = 1;
        }

        mat::Params::setNProcs(defThreads);
        mat::Params::setMatrixParallelLevel(2);
        std::cout<<"OpenMP is used, number of threads set to "
                 <<mat::Params::getNProcs()<<". Matrix parallel level: "
                 <<mat::Params::getMatrixParallelLevel()<<"."<<std::endl;
  #endif

        // set seed for srand
        // the seed it chosen such that generated matrices have enough large gaps and purifications converge
        int SEED = 1000;
        int num_iter_sp2 = -1;
        //enable_printf_output(); // write more debug info about each iteration
        bool puri_print_collected_info = false;

        for (int TESTNUM = 1; TESTNUM <= TEST_COUNT; TESTNUM++)
        {
                /* We want to create the same matrix for some tests. For example, we want to check the convergence when eigenvalue bounds are known and when not known. */
                srand (SEED);


                printf("TEST NUMBER %d\n", TESTNUM);
                data_for_recexp_t DATA;
                printf("Getting data...\n");
                prepare_data_for_recexp(TESTNUM, DATA);

                MatrixType F;
                int blockSizesMultuple = 4; // set matrix structure
                get_random_symm_matrix(DATA.N, F, blockSizesMultuple);
                printf("Created random symmetric matrix F.\n");
                //print_ergo_matrix(F);


                if(DATA.homo_lumo_bounds_known)
                {
                        // Get all eigenvalues of F. We need this to get bounds for homo and lumo for F.
                        std::vector<ergo_real> eigvalList;
                        get_all_eigenvalues_of_matrix(eigvalList, F);
                        ergo_real homo = eigvalList[DATA.N_occ-1];
                        ergo_real lumo = eigvalList[DATA.N_occ  ];
                        ergo_real epsilon_for_homo_lumo_intervals = 1e-4;
                        DATA.homo_out = homo-epsilon_for_homo_lumo_intervals;
                        DATA.homo_in  = homo+epsilon_for_homo_lumo_intervals;
                        DATA.lumo_in  = lumo-epsilon_for_homo_lumo_intervals;
                        DATA.lumo_out = lumo+epsilon_for_homo_lumo_intervals;

                        printf("homo bounds: [%lf, %lf]\n", (double)DATA.homo_out, (double)DATA.homo_in);
                        printf("lumo bounds: [%lf, %lf]\n", (double)DATA.lumo_in, (double)DATA.lumo_out);
                }
                else
                        printf("Homo and lumo bounds are not known.\n");

                // SET HOMO AND LUMO BOUNDS FOR F
                IntervalType homo_bounds = IntervalType(DATA.homo_out, DATA.homo_in);
                IntervalType lumo_bounds = IntervalType(DATA.lumo_in, DATA.lumo_out);

                if( homo_bounds.empty() )
                {
                        printf("Interval homo_bounds is empty.\n");
                        return EXIT_FAILURE;
                }
                if ( lumo_bounds.empty() )
                {
                        printf("Interval lumo_bounds is empty.\n");
                        return EXIT_FAILURE;
                }

                printf("\n");

                /***** START OF SP2 TEST ******/
                {
                        Purification_sp2<MatrixType> Puri;

                        Puri.initialize(F,
                                        lumo_bounds,
                                        homo_bounds,
                                        DATA.maxit,
                                        DATA.err_sub,
                                        DATA.err_eig,
                                        DATA.which_stop_crit, // 1 = new, 0 = old stopping criterion
                                        DATA.normPuri,
                                        DATA.normPuriStopCrit,
                                        DATA.N_occ);

                        // RUN RECURSIVE EXPANSION
                        printf("Start SP2 recursive expansion...\n");
                        Puri.PurificationStart();

                        if(puri_print_collected_info)
                                Puri.info.print_collected_info_printf();

                        // CHECK RESULT OF THE RECURSIVE EXPANSION
                        if (Puri.info.converged != 1)
                        {
                                throw std::runtime_error("SP2 did not converge!");
                        }
                        else
                                printf("SP2 converged in %d iterations.\n", Puri.info.total_it);

                        MatrixType X(Puri.X);
                        ergo_real traceX = X.trace();
                        if (template_blas_fabs(traceX - DATA.N_occ) > TOL_TRACE_ERROR_DEFAULT)
                        {
                                throw std::runtime_error("SP2: Wrong value of trace! (abs(traceX - N_occ) > TOL_TRACE_ERROR_DEFAULT)");
                        }

                        num_iter_sp2 = Puri.info.total_it;

                        if(Puri.info.accumulated_error_subspace >= 0)
                        {
                                printf("Maximum allowed error in subspace: %g\n"
                                       "Accumulated error in subspace: %g\n",
                                       (double)Puri.info.error_subspace, (double)Puri.info.accumulated_error_subspace );
                                assert(Puri.info.error_subspace >= Puri.info.accumulated_error_subspace);
                        }

                        // try to plot some staff
                        //plot_results(Puri);

                        printf("\n");
                }

                /*******  END OF SP2 TEST  ************/


                /***** START OF SP2ACC TEST ******/
                {
                        Purification_sp2acc<MatrixType> Puri;

                        Puri.initialize(F,
                                        lumo_bounds,
                                        homo_bounds,
                                        DATA.maxit,
                                        DATA.err_sub,
                                        DATA.err_eig,
                                        DATA.which_stop_crit, // 1 = new, 0 = old stopping criterion
                                        DATA.normPuri,
                                        DATA.normPuriStopCrit,
                                        DATA.N_occ);

                        // RUN RECURSIVE EXPANSION
                        printf("Start SP2ACC recursive expansion...\n");
                        Puri.PurificationStart();

                        if(puri_print_collected_info)
                                Puri.info.print_collected_info_printf();

                        // CHECK RESULT OF THE RECURSIVE EXPANSION
                        if (Puri.info.converged != 1)
                        {
                                throw std::runtime_error("SP2ACC did not converge!");
                        }
                        else
                                printf("SP2ACC converged in %d iterations.\n", Puri.info.total_it);

                        MatrixType X(Puri.X);
                        ergo_real traceX = X.trace();
                        if (template_blas_fabs(traceX - DATA.N_occ) > TOL_TRACE_ERROR_DEFAULT)
                        {
                                throw std::runtime_error("SP2ACC: Wrong value of trace! (abs(traceX - N_occ) > TOL_TRACE_ERROR_DEFAULT)");
                        }

                      /*
                      If homo and lumo eigenvalues are unknown, then sp2 and sp2acc are equivalent. We are using Lanczos algorithm to improve Gershgorin spectrum bounds. Since starting guesses are random, Lanczos may converge to slightly different values from run to run. Thus it may happen that the number of iterations is sp2 is smaller the in sp2acc. 
                      In this check we allow to have 3 more iterations in sp2acc. That should be enough.
                      */
                        if(num_iter_sp2 + 3 < Puri.info.total_it)
                        {
                                throw std::runtime_error("SP2ACC: number of iterations required for the SP2ACC convergence is less than the number of iterations required for the SP2 convergence\n");
                        }

                        if(Puri.info.accumulated_error_subspace >= 0)
                        {
                                printf("Maximum allowed error in subspace: %g\n"
                                       "Accumulated error in subspace: %g\n",
                                       (double)Puri.info.error_subspace, (double)Puri.info.accumulated_error_subspace );
                                assert(Puri.info.error_subspace >= Puri.info.accumulated_error_subspace);
                        }

                        printf("\n");
                }

                /*******  END OF SP2ACC TEST  ************/



                F.clear();

        }

        printf("DONE!\n");


        return EXIT_SUCCESS;
}



void prepare_data_for_recexp(int testnum, data_for_recexp_t &DATA)
{
        switch(testnum)
        {
        case 1: {set_test_1_data(DATA); break;}
        case 2: {set_test_2_data(DATA); break;}
        case 3: {set_test_3_data(DATA); break;}
        case 4: {set_test_4_data(DATA); break;}
        case 5: {set_test_5_data(DATA); break;}
        case 6: {set_test_6_data(DATA); break;}
        case 7: {set_test_7_data(DATA); break;}
        case 8: {set_test_8_data(DATA); break;}
        //case 9: {set_test_9_data(DATA); break;}
        default:
                throw std::runtime_error("wrong value of testnum in prepare_data_for_recexp");
        }

        printf("N = %d\n", DATA.N);
        printf("N_occ = %d\n", DATA.N_occ);
        printf("err_sub = %g\n", (double)DATA.err_sub);
        printf("err_eig = %g\n", (double)DATA.err_eig);
        printf("Chosen norm for the stopping criterion: ");
        switch (DATA.normPuriStopCrit)
        {
        case mat::mixedNorm:
                printf("mixed\n");
                break;

        case mat::euclNorm:
                printf("eucl\n");
                break;

        case mat::frobNorm:
                printf("frob\n");
                break;

        default:
                throw std::runtime_error("Unknown norm in prepare_data_for_recexp");
        }

        printf("Chosen norm for the truncation: ");
        switch (DATA.normPuri)
        {
        case mat::mixedNorm:
                printf("mixed\n");
                break;

        case mat::euclNorm:
                printf("eucl\n");
                break;

        case mat::frobNorm:
                printf("frob\n");
                break;

        default:
                throw std::runtime_error("Unknown norm in prepare_data_for_recexp");
        }
        if(DATA.which_stop_crit == 0)
                printf("Use old stopping criterion.\n");
        else
                printf("Use new stopping criterion.\n");



}


/*
TEST 1
N = 123, N_occ = 50 (Nocc is approx N/2). Homo-lumo bounds are known, the spectral norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_1_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 50;
        DATA.N = 123;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::euclNorm;
        DATA.which_stop_crit = 1; // new

        // Placeholder bounds only: since the bounds are marked as known, the main
        // function replaces them with intervals around the computed homo and lumo.
        DATA.homo_lumo_bounds_known = true;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}

/*
TEST 2
N = 123, N_occ = 50 (Nocc is approx N/2). Homo-lumo bounds are NOT known, the spectral norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_2_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 50;
        DATA.N = 123;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::euclNorm;
        DATA.which_stop_crit = 1; // new

        // The bounds are marked as unknown, so the main function does not
        // overwrite them and the intervals [-MAX_DOUBLE, MAX_DOUBLE] are used.
        DATA.homo_lumo_bounds_known = false;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}

/*
TEST 3
Matrix of 1 element (N = 1), nocc is 1. Homo-lumo bounds are NOT known, the spectral norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_3_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 1;
        DATA.N = 1;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::euclNorm;
        DATA.which_stop_crit = 1; // new

        // The bounds are marked as unknown, so the main function does not
        // overwrite them and the intervals [-MAX_DOUBLE, MAX_DOUBLE] are used.
        DATA.homo_lumo_bounds_known = false;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}


/*
TEST 4
N = 87, N_occ = 5 (Nocc is small compared to N). Homo-lumo bounds are known, the Frobenius norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_4_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 5;
        DATA.N = 87;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::frobNorm;
        DATA.which_stop_crit = 1; // new

        // Placeholder bounds only: since the bounds are marked as known, the main
        // function replaces them with intervals around the computed homo and lumo.
        DATA.homo_lumo_bounds_known = true;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}


/*
TEST 5
N = 87, N_occ = 5 (Nocc is small compared to N). Homo-lumo bounds are NOT known, the Frobenius norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_5_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 5;
        DATA.N = 87;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::frobNorm;
        DATA.which_stop_crit = 1; // new

        // The bounds are marked as unknown, so the main function does not
        // overwrite them and the intervals [-MAX_DOUBLE, MAX_DOUBLE] are used.
        DATA.homo_lumo_bounds_known = false;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}

/*
TEST 6
N = 87, N_occ = 5 (Nocc is small compared to N). Homo-lumo bounds are known, the mixed norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_6_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 5;
        DATA.N = 87;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::mixedNorm;
        DATA.which_stop_crit = 1; // new

        // Placeholder bounds only: since the bounds are marked as known, the main
        // function replaces them with intervals around the computed homo and lumo.
        DATA.homo_lumo_bounds_known = true;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}


/*
TEST 7
N = 87, N_occ = 5 (Nocc is small compared to N). Homo-lumo bounds are NOT known, the mixed norm is used for truncation and stopping criterion, new stopping criterion.
*/
void set_test_7_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 5;
        DATA.N = 87;
        DATA.maxit = 100;

        // error control
        DATA.err_eig = -1; //dummy
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::mixedNorm;
        DATA.which_stop_crit = 1; // new

        // The bounds are marked as unknown, so the main function does not
        // overwrite them and the intervals [-MAX_DOUBLE, MAX_DOUBLE] are used.
        DATA.homo_lumo_bounds_known = false;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}

/*
TEST 8
N = 87, N_occ = 5 (Nocc is small compared to N). Homo-lumo bounds are NOT known, the mixed norm is used for truncation and stopping criterion, and the old stopping criterion is used.
*/
void set_test_8_data(data_for_recexp_t &DATA)
{
        // problem size and iteration limit
        DATA.N_occ = 5;
        DATA.N = 87;
        DATA.maxit = 100;

        // error control; unlike the other tests, err_eig is not a dummy value here
        DATA.err_eig = TOL_ERR_SUBS_DEFAULT/10;
        DATA.err_sub = TOL_ERR_SUBS_DEFAULT;

        // same norm for truncation and for the stopping criterion
        DATA.normPuri = DATA.normPuriStopCrit = mat::mixedNorm;
        DATA.which_stop_crit = 0; // old

        // The bounds are marked as unknown, so the main function does not
        // overwrite them and the intervals [-MAX_DOUBLE, MAX_DOUBLE] are used.
        DATA.homo_lumo_bounds_known = false;
        DATA.homo_out = DATA.lumo_in = -MAX_DOUBLE;
        DATA.homo_in = DATA.lumo_out = MAX_DOUBLE;

        DATA.compute_eigenvectors = false;
}




/*
TEST 9
N is 10, nocc is 5, homo-lumo bounds are known, use mixed norm for truncation and stopping criterion, check that number of iterations is small

Anastasia note: the test is removed since rand generates different numbers on different machines. The random matrices are then different, and thus the purification requires a different number of iterations.
*/
// void set_test_9_data(data_for_recexp_t &DATA)
// {
//         DATA.N = 10;
//         DATA.N_occ = 5;
//         DATA.err_sub = 0;  // no truncation
//         DATA.err_eig = -1; //dummy;
//         DATA.maxit = 23; // remind that in quad precision we may need more iterations
//         DATA.homo_lumo_bounds_known = true;
//         // values of homo and lumo will be estimated in the main function
//         DATA.homo_out = -MAX_DOUBLE;
//         DATA.homo_in = MAX_DOUBLE;
//         DATA.lumo_in = -MAX_DOUBLE;
//         DATA.lumo_out = MAX_DOUBLE;
//         DATA.normPuri = mat::mixedNorm;
//         DATA.normPuriStopCrit = mat::mixedNorm;
//         DATA.which_stop_crit = 1; // new
//         DATA.compute_eigenvectors = false;
// }



#endif
