diff options
Diffstat (limited to 'eigen/Eigen/src/Core/SolveTriangular.h')
-rw-r--r-- | eigen/Eigen/src/Core/SolveTriangular.h | 77 |
1 files changed, 24 insertions, 53 deletions
diff --git a/eigen/Eigen/src/Core/SolveTriangular.h b/eigen/Eigen/src/Core/SolveTriangular.h index 30c9c38..a0011d4 100644 --- a/eigen/Eigen/src/Core/SolveTriangular.h +++ b/eigen/Eigen/src/Core/SolveTriangular.h @@ -68,7 +68,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1> if(!useRhsDirectly) MappedRhs(actualRhs,rhs.size()) = rhs; - triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate, + triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate, (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor> ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); @@ -82,7 +82,6 @@ template<typename Lhs, typename Rhs, int Side, int Mode> struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> { typedef typename Rhs::Scalar Scalar; - typedef typename Rhs::Index Index; typedef blas_traits<Lhs> LhsProductTraits; typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; @@ -96,7 +95,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType; - BlockingType blocking(rhs.rows(), rhs.cols(), size); + BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false); triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor> @@ -108,32 +107,32 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> * meta-unrolling implementation ***************************************************************************/ -template<typename Lhs, typename Rhs, int Mode, int Index, int Size, - bool Stop = Index==Size> +template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size, + bool Stop = LoopIndex==Size> struct triangular_solver_unroller; -template<typename Lhs, typename Rhs, int Mode, int Index, int Size> -struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> { +template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> +struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> { enum { IsLower = ((Mode&Lower)==Lower), - RowIndex = IsLower ? Index : Size - Index - 1, - S = IsLower ? 0 : RowIndex+1 + DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1, + StartIndex = IsLower ? 0 : DiagIndex+1 }; static void run(const Lhs& lhs, Rhs& rhs) { - if (Index>0) - rhs.coeffRef(RowIndex) -= lhs.row(RowIndex).template segment<Index>(S).transpose() - .cwiseProduct(rhs.template segment<Index>(S)).sum(); + if (LoopIndex>0) + rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose() + .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex)).sum(); if(!(Mode & UnitDiag)) - rhs.coeffRef(RowIndex) /= lhs.coeff(RowIndex,RowIndex); + rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex,DiagIndex); - triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs); + triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex+1,Size>::run(lhs,rhs); } }; -template<typename Lhs, typename Rhs, int Mode, int Index, int Size> -struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> { +template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> +struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> { static void run(const Lhs&, Rhs&) {} }; @@ -162,61 +161,35 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> { * TriangularView methods ***************************************************************************/ -/** "in-place" version of TriangularView::solve() where the result is written in \a other - * - * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. - * This function will const_cast it, so constness isn't honored here. - * - * See TriangularView:solve() for the details. - */ +#ifndef EIGEN_PARSED_BY_DOXYGEN template<typename MatrixType, unsigned int Mode> template<int Side, typename OtherDerived> -void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived>& _other) const +EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const { OtherDerived& other = _other.const_cast_derived(); - eigen_assert( cols() == rows() && ((Side==OnTheLeft && cols() == other.rows()) || (Side==OnTheRight && cols() == other.cols())) ); + eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) ); eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); - enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime }; + enum { copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime!=1}; typedef typename internal::conditional<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; OtherCopy otherCopy(other); internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type, - Side, Mode>::run(nestedExpression(), otherCopy); + Side, Mode>::run(derived().nestedExpression(), otherCopy); if (copy) other = otherCopy; } -/** \returns the product of the inverse of \c *this with \a other, \a *this being triangular. - * - * This function computes the inverse-matrix matrix product inverse(\c *this) * \a other if - * \a Side==OnTheLeft (the default), or the right-inverse-multiply \a other * inverse(\c *this) if - * \a Side==OnTheRight. - * - * The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the - * diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this - * is an upper (resp. lower) triangular matrix. - * - * Example: \include MatrixBase_marked.cpp - * Output: \verbinclude MatrixBase_marked.out - * - * This function returns an expression of the inverse-multiply and can works in-place if it is assigned - * to the same matrix or vector \a other. - * - * For users coming from BLAS, this function (and more specifically solveInPlace()) offer - * all the operations supported by the \c *TRSV and \c *TRSM BLAS routines. - * - * \sa TriangularView::solveInPlace() - */ template<typename Derived, unsigned int Mode> template<int Side, typename Other> const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other> -TriangularView<Derived,Mode>::solve(const MatrixBase<Other>& other) const +TriangularViewImpl<Derived,Mode,Dense>::solve(const MatrixBase<Other>& other) const { - return internal::triangular_solve_retval<Side,TriangularView,Other>(*this, other.derived()); + return internal::triangular_solve_retval<Side,TriangularViewType,Other>(derived(), other.derived()); } +#endif namespace internal { @@ -232,7 +205,6 @@ template<int Side, typename TriangularType, typename Rhs> struct triangular_solv { typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned; typedef ReturnByValue<triangular_solve_retval> Base; - typedef typename Base::Index Index; triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) : m_triangularMatrix(tri), m_rhs(rhs) @@ -243,8 +215,7 @@ template<int Side, typename TriangularType, typename Rhs> struct triangular_solv template<typename Dest> inline void evalTo(Dest& dst) const { - const typename Dest::Scalar *dst_data = internal::extract_data(dst); - if(!(is_same<RhsNestedCleaned,Dest>::value && dst_data!=0 && extract_data(dst) == extract_data(m_rhs))) + if(!is_same_dense(dst,m_rhs)) dst = m_rhs; m_triangularMatrix.template solveInPlace<Side>(dst); } |