// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
#define EIGEN_ITERATIVE_SOLVER_BASE_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

/** \internal
 * Compile-time detection of whether an lvalue of type \c MatrixType can be converted to \c Ref<const MatrixType>.
 *
 * Uses the overload-resolution/sizeof idiom. The call test<MatrixType>(ms_from, 0) only appears inside sizeof,
 * so nothing is evaluated at run time.
 *  - The first overload takes a \c Ref<const T> and an \c int. When it is viable it wins: the literal 0 matches
 *    \c int exactly, whereas the second overload only accepts it through the ellipsis.
 *  - The second overload takes an \c any_conversion<T>, constructible from any lvalue. It therefore keeps the
 *    call well-formed and is selected when the conversion to \c Ref<const T> is not viable.
 * \c value compares the size of the selected return type (\c yes vs. \c no) to tell which overload was chosen.
 *
 * NOTE(review): this only reflects what overload resolution sees. If a Ref<> constructor were viable by
 * signature but rejected the type inside its body, \c value would still be true. Confirm against Ref<>'s
 * constraints before relying on it for new matrix types.
 */
template <typename MatrixType>
struct is_ref_compatible_impl {
 private:
  // Implicitly constructible from any lvalue: the catch-all first parameter of the fallback overload.
  template <typename T0>
  struct any_conversion {
    template <typename T>
    any_conversion(const volatile T&);
    template <typename T>
    any_conversion(T&);
  };
  // Two types of distinct sizes, used to identify the selected overload through sizeof.
  struct yes {
    int a[1];
  };
  struct no {
    int a[2];
  };

  // Preferred overload: viable only if ms_from converts to Ref<const T>.
  template <typename T>
  static yes test(const Ref<const T>&, int);
  // Fallback overload: always viable, but ranked worse because of the ellipsis.
  template <typename T>
  static no test(any_conversion<T>, ...);

 public:
  // Declared only. It is referenced solely in the unevaluated sizeof operand below, so it needs no definition.
  static MatrixType ms_from;
  enum { value = sizeof(test<MatrixType>(ms_from, 0)) == sizeof(yes) };
};

template <typename MatrixType>
struct is_ref_compatible {
  enum { value = is_ref_compatible_impl<remove_all_t<MatrixType>>::value };
};

template <typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
class generic_matrix_wrapper;

// We have an explicit matrix at hand, compatible with Ref<>.
// The wrapper accesses the user's matrix through a Ref<const MatrixType> member.
template <typename MatrixType>
class generic_matrix_wrapper<MatrixType, false> {
 public:
  // Type through which the solver reads the matrix.
  typedef Ref<const MatrixType> ActualMatrixType;
  // Self-adjoint view type, forwarded from ActualMatrixType.
  template <int UpLo>
  struct ConstSelfAdjointViewReturnType {
    typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
  };

  enum { MatrixFree = false };

  // Empty state: m_matrix is bound to the 0x0 m_dummy owned by this object.
  // m_dummy is declared before m_matrix, so it is already constructed when m_matrix binds to it.
  // NOTE(review): the implicitly generated copy/move constructors copy m_matrix as-is. If the source was in
  // this empty state, confirm that the copy does not keep referring to the source's m_dummy.
  generic_matrix_wrapper() : m_dummy(0, 0), m_matrix(m_dummy) {}

  // Binds m_matrix to mat (converted to Ref<const MatrixType>).
  template <typename InputType>
  generic_matrix_wrapper(const InputType& mat) : m_matrix(mat) {}

  const ActualMatrixType& matrix() const { return m_matrix; }

  // Rebinds m_matrix to mat by destroying it and re-constructing it in place
  // (presumably because a Ref<> cannot be re-seated by assignment).
  // NOTE(review): if constructing the Ref<> can throw (e.g. when it has to allocate a copy of mat), m_matrix is
  // left destroyed and would be destroyed again by this object's destructor. Confirm whether that can happen.
  template <typename MatrixDerived>
  void grab(const EigenBase<MatrixDerived>& mat) {
    internal::destroy_at(&m_matrix);
    internal::construct_at(&m_matrix, mat.derived());
  }

  // Overload for an argument that already is a Ref<const MatrixType>.
  // The address check skips the rebind when mat is m_matrix itself, e.g. when IterativeSolverBase's constructor
  // calls compute(matrix()). Destroying m_matrix first would otherwise leave construct_at reading a dead object.
  void grab(const Ref<const MatrixType>& mat) {
    if (&(mat.derived()) != &m_matrix) {
      internal::destroy_at(&m_matrix);
      internal::construct_at(&m_matrix, mat);
    }
  }

 protected:
  MatrixType m_dummy;  // used to default initialize the Ref<> object
  ActualMatrixType m_matrix;
};

// MatrixType is not compatible with Ref<> -> matrix-free wrapper.
// Only a non-owning pointer to the user's object is stored. That object must therefore stay alive,
// and at the same address, for as long as matrix() is used.
template <typename MatrixType>
class generic_matrix_wrapper<MatrixType, true> {
 public:
  // The user's type is used directly; no Ref<> indirection is involved.
  typedef MatrixType ActualMatrixType;
  // No self-adjoint view is formed in matrix-free mode: the "view" type is the operator type itself.
  template <int UpLo>
  struct ConstSelfAdjointViewReturnType {
    typedef ActualMatrixType Type;
  };

  enum { MatrixFree = true };

  /** Creates an empty wrapper. A matrix must be attached (constructor or grab()) before matrix() is called. */
  generic_matrix_wrapper() : mp_matrix(nullptr) {}

  /** Attaches \a mat. Only its address is stored. */
  generic_matrix_wrapper(const MatrixType& mat) : mp_matrix(&mat) {}

  /** \returns the attached matrix-free operator.
   * Dereferencing the null pointer of an empty wrapper would be undefined behavior, so this is asserted.
   */
  const ActualMatrixType& matrix() const {
    eigen_assert(mp_matrix != nullptr && "generic_matrix_wrapper: no matrix-free operator has been attached");
    return *mp_matrix;
  }

  /** Attaches \a mat, replacing any previously attached object. Only its address is stored. */
  void grab(const MatrixType& mat) { mp_matrix = &mat; }

 protected:
  const ActualMatrixType* mp_matrix;
};

}  // namespace internal

/** \ingroup IterativeLinearSolvers_Module
 * \brief Base class for linear iterative solvers
 *
 * This class keeps track of three things:
 *  - the matrix A of the system \c Ax=b, through internal::generic_matrix_wrapper. This is a Ref<> for
 *    matrices compatible with it, or a plain pointer to the user object for matrix-free operators;
 *  - the preconditioner;
 *  - the stopping criteria (tolerance and maximal number of iterations).
 * The iteration itself is provided by the derived class through _solve_vector_with_guess_impl(). Right-hand
 * sides with several columns are solved one column at a time.
 *
 * \sa class ConjugateGradient, BiCGSTAB, SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner
 */
template <typename Derived>
class IterativeSolverBase : public SparseSolverBase<Derived> {
 protected:
  typedef SparseSolverBase<Derived> Base;
  using Base::m_isInitialized;

 public:
  typedef typename internal::traits<Derived>::MatrixType MatrixType;
  typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
  typedef typename MatrixType::Scalar Scalar;
  typedef typename MatrixType::StorageIndex StorageIndex;
  typedef typename MatrixType::RealScalar RealScalar;

  enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };

 public:
  using Base::derived;

  /** Default constructor. */
  IterativeSolverBase() { init(); }

  /** Initialize the solver with matrix \a A for further \c Ax=b solving.
   *
   * This constructor is a shortcut for the default constructor followed
   * by a call to compute().
   *
   * \warning this class stores a reference to the matrix A as well as some
   * precomputed values that depend on it. Therefore, if \a A is changed
   * this class becomes invalid. Call compute() to update it with the new
   * matrix A, or modify a copy of A.
   */
  template <typename MatrixDerived>
  explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A) : m_matrixWrapper(A.derived()) {
    init();
    compute(matrix());
  }

  IterativeSolverBase(IterativeSolverBase&&) = default;

  ~IterativeSolverBase() {}

  /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems.
   *
   * Currently, this function mostly calls analyzePattern on the preconditioner. In the future
   * we might, for instance, implement column reordering for faster matrix vector products.
   */
  template <typename MatrixDerived>
  Derived& analyzePattern(const EigenBase<MatrixDerived>& A) {
    grab(A.derived());
    m_preconditioner.analyzePattern(matrix());
    m_isInitialized = true;
    m_analysisIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }

  /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b
   * problems.
   *
   * Currently, this function mostly calls factorize on the preconditioner.
   *
   * \warning this class stores a reference to the matrix A as well as some
   * precomputed values that depend on it. Therefore, if \a A is changed
   * this class becomes invalid. Call compute() to update it with the new
   * matrix A, or modify a copy of A.
   */
  template <typename MatrixDerived>
  Derived& factorize(const EigenBase<MatrixDerived>& A) {
    eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
    grab(A.derived());
    m_preconditioner.factorize(matrix());
    m_factorizationIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }

  /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems.
   *
   * Currently, this function mostly initializes/computes the preconditioner. In the future
   * we might, for instance, implement column reordering for faster matrix vector products.
   *
   * \warning this class stores a reference to the matrix A as well as some
   * precomputed values that depend on it. Therefore, if \a A is changed
   * this class becomes invalid. Call compute() to update it with the new
   * matrix A, or modify a copy of A.
   */
  template <typename MatrixDerived>
  Derived& compute(const EigenBase<MatrixDerived>& A) {
    grab(A.derived());
    m_preconditioner.compute(matrix());
    m_isInitialized = true;
    m_analysisIsOk = true;
    m_factorizationIsOk = true;
    m_info = m_preconditioner.info();
    return derived();
  }

  /** \internal */
  EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }

  /** \internal */
  EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }

  /** \returns the tolerance threshold used by the stopping criteria.
   * \sa setTolerance()
   */
  RealScalar tolerance() const { return m_tolerance; }

  /** Sets the tolerance threshold used by the stopping criteria.
   *
   * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|.
   * The default value is the machine precision given by NumTraits<Scalar>::epsilon()
   */
  Derived& setTolerance(const RealScalar& tolerance) {
    m_tolerance = tolerance;
    return derived();
  }

  /** \returns a read-write reference to the preconditioner for custom configuration. */
  Preconditioner& preconditioner() { return m_preconditioner; }

  /** \returns a read-only reference to the preconditioner. */
  const Preconditioner& preconditioner() const { return m_preconditioner; }

  /** \returns the max number of iterations.
   * It is either the value set by setMaxIterations or, by default,
   * twice the number of columns of the matrix.
   */
  Index maxIterations() const { return (m_maxIterations < 0) ? 2 * matrix().cols() : m_maxIterations; }

  /** Sets the max number of iterations.
   * Default is twice the number of columns of the matrix.
   */
  Derived& setMaxIterations(Index maxIters) {
    m_maxIterations = maxIters;
    return derived();
  }

  /** \returns the number of iterations performed during the last solve, or 0 if no solve has been performed yet */
  Index iterations() const {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_iterations;
  }

  /** \returns the tolerance error reached during the last solve, or 0 if no solve has been performed yet.
   * It is a close approximation of the true relative residual error |Ax-b|/|b|.
   */
  RealScalar error() const {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_error;
  }

  /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A
   * and \a x0 as an initial solution.
   *
   * \sa solve(), compute()
   */
  template <typename Rhs, typename Guess>
  inline const SolveWithGuess<Derived, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const {
    eigen_assert(m_isInitialized && "Solver is not initialized.");
    eigen_assert(derived().rows() == b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
    return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
  }

  /** \returns Success if the iterations converged, and NoConvergence otherwise. */
  ComputationInfo info() const {
    eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
    return m_info;
  }

  /** \internal Solves for a sparse destination, column by column, using the current content of \a aDest as the
   * initial guess. */
  template <typename Rhs, typename DestDerived>
  void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived>& aDest) const {
    eigen_assert(rows() == b.rows());

    Index rhsCols = b.cols();
    Index size = b.rows();
    DestDerived& dest(aDest.derived());
    typedef typename DestDerived::Scalar DestScalar;
    Eigen::Matrix<DestScalar, Dynamic, 1> tb(size);
    Eigen::Matrix<DestScalar, Dynamic, 1> tx(cols());
    // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
    // For non square least-square problems, b and dest might not have the same size whereas they might alias
    // each-other.
    typename DestDerived::PlainObject tmp(cols(), rhsCols);
    ComputationInfo global_info = Success;
    for (Index k = 0; k < rhsCols; ++k) {
      tb = b.col(k);
      tx = dest.col(k);
      derived()._solve_vector_with_guess_impl(tb, tx);
      tmp.col(k) = tx.sparseView(0);

      // The call to _solve_vector_with_guess_impl overwrites m_info for each column, so accumulate the
      // worst status over all columns and publish it once at the end.
      global_info = mergeColumnInfo(global_info, m_info);
    }
    m_info = global_info;
    dest.swap(tmp);
  }

  /** \internal Solves for a dense destination with (possibly) several columns, column by column, using the
   * current content of \a aDest as the initial guess. */
  template <typename Rhs, typename DestDerived>
  std::enable_if_t<Rhs::ColsAtCompileTime != 1 && DestDerived::ColsAtCompileTime != 1> _solve_with_guess_impl(
      const Rhs& b, MatrixBase<DestDerived>& aDest) const {
    eigen_assert(rows() == b.rows());

    Index rhsCols = b.cols();
    DestDerived& dest(aDest.derived());
    ComputationInfo global_info = Success;
    for (Index k = 0; k < rhsCols; ++k) {
      typename DestDerived::ColXpr xk(dest, k);
      typename Rhs::ConstColXpr bk(b, k);
      derived()._solve_vector_with_guess_impl(bk, xk);

      // The call to _solve_vector_with_guess_impl overwrites m_info for each column, so accumulate the
      // worst status over all columns and publish it once at the end.
      global_info = mergeColumnInfo(global_info, m_info);
    }
    m_info = global_info;
  }

  /** \internal Single-column case: forwarded directly to the derived solver, which sets m_info itself. */
  template <typename Rhs, typename DestDerived>
  std::enable_if_t<Rhs::ColsAtCompileTime == 1 || DestDerived::ColsAtCompileTime == 1> _solve_with_guess_impl(
      const Rhs& b, MatrixBase<DestDerived>& dest) const {
    derived()._solve_vector_with_guess_impl(b, dest.derived());
  }

  /** \internal default initial guess = 0 */
  template <typename Rhs, typename Dest>
  void _solve_impl(const Rhs& b, Dest& x) const {
    x.setZero();
    derived()._solve_with_guess_impl(b, x);
  }

 protected:
  void init() {
    m_isInitialized = false;
    m_analysisIsOk = false;
    m_factorizationIsOk = false;
    m_maxIterations = -1;
    m_tolerance = NumTraits<Scalar>::epsilon();
    // Give the per-solve statistics well-defined values. Otherwise iterations(), error() and the defaulted move
    // constructor would read indeterminate values when used after compute() but before the first solve.
    m_error = RealScalar(0);
    m_iterations = 0;
    m_info = Success;
  }

  /** \internal Folds the status of one solved column into the aggregate status of a multi-column solve, keeping
   * the worst outcome seen so far. NumericalIssue takes precedence over NoConvergence, which takes precedence
   * over Success; any other column status leaves the aggregate unchanged.
   */
  static ComputationInfo mergeColumnInfo(ComputationInfo aggregate, ComputationInfo column) {
    if (aggregate == NumericalIssue || column == NumericalIssue) return NumericalIssue;
    if (column == NoConvergence) return NoConvergence;
    return aggregate;
  }

  typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
  typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;

  const ActualMatrixType& matrix() const { return m_matrixWrapper.matrix(); }

  template <typename InputType>
  void grab(const InputType& A) {
    m_matrixWrapper.grab(A);
  }

  MatrixWrapper m_matrixWrapper;
  Preconditioner m_preconditioner;

  Index m_maxIterations;
  RealScalar m_tolerance;

  mutable RealScalar m_error;
  mutable Index m_iterations;
  mutable ComputationInfo m_info;
  mutable bool m_analysisIsOk, m_factorizationIsOk;
};

}  // end namespace Eigen

#endif  // EIGEN_ITERATIVE_SOLVER_BASE_H
