简而言之:我试图推断一个对矩阵执行一元运算的函数的返回类型(它是一个表达式模板)。
在这种情况下,操作是计算协方差矩阵。
我在这里遵循了 Eigen 文档:
https://eigen.tuxfamily.org/dox/TopicCustomizing_NullaryExpr.html
并编写了一个将矩阵乘以 2 并返回结果的最小示例。下面的代码片段是一个可以正常工作的示例。对我来说,关键在于表达式不会被求值为中间结果,所以我不想返回类似 `Eigen::MatrixBase<Derived>` 这样的类型。
// Compile-time helper that names the types involved in "multiply a matrix by
// two" without evaluating anything.
//
// ArgType: the concrete Eigen matrix type of the argument. It must expose
// Scalar, Rows/ColsAtCompileTime, Options and Max{Rows,Cols}AtCompileTime,
// as Eigen::Matrix does.
template<typename ArgType>
struct times_two_helper {
  // Returns the unevaluated expression `mat + mat`; no arithmetic happens here.
  template<typename Derived>
  static auto TimesTwo(Eigen::MatrixBase<Derived> const& mat) {
    return mat + mat;
  }

  // Plain matrix type with the same shape as ArgType.
  // Eigen::Matrix takes <Scalar, Rows, Cols, Options, MaxRows, MaxCols>.
  // The previous Cols-before-Rows order described the transposed shape:
  // - NullaryExpr(rows, cols, ...) would fail for non-square fixed-size inputs;
  // - column vectors would produce an invalid ColMajor row-vector type.
  using ResultType = Eigen::Matrix<typename ArgType::Scalar,
                                   ArgType::RowsAtCompileTime,
                                   ArgType::ColsAtCompileTime,
                                   ArgType::Options,
                                   ArgType::MaxRowsAtCompileTime,
                                   ArgType::MaxColsAtCompileTime>;

  // Exact (unevaluated) expression type TimesTwo yields for an ArgType.
  using ExpressionType = decltype(TimesTwo(std::declval<ArgType>()));
};
// Nullary functor for Eigen's NullaryExpr. It keeps the unevaluated
// `arg + arg` expression and reads coefficient (row, col) from it on demand.
// NOTE(review): Eigen normally nests plain matrices by reference, so `arg`
// presumably must outlive this functor and any NullaryExpr built from it.
template<typename ArgType>
struct times_two_functor {
  using ResultType = typename times_two_helper<ArgType>::ResultType;
  using ExpressionType = typename times_two_helper<ArgType>::ExpressionType;

  times_two_functor(ArgType const& arg)
      : arg_(arg),
        expression_(times_two_helper<ArgType>::TimesTwo(arg)) {}

  // Coefficient (row, col) of 2 * arg, computed from the stored expression.
  typename ArgType::Scalar operator()(Eigen::Index row, Eigen::Index col) const {
    return expression_(row, col);
  }

  ArgType const& arg_;
  // Stored unevaluated on purpose: no intermediate matrix is materialised.
  ExpressionType expression_;
};
// Returns a lazy `2 * arg` expression backed by times_two_functor.
// Coefficients are produced only when the result is read or assigned.
template <class ArgType>
Eigen::CwiseNullaryOp<times_two_functor<ArgType>, typename times_two_helper<ArgType>::ResultType>
TimesTwo(Eigen::MatrixBase<ArgType> const& arg) {
  using Plain = typename times_two_helper<ArgType>::ResultType;
  const times_two_functor<ArgType> functor(arg.derived());
  return Plain::NullaryExpr(arg.rows(), arg.cols(), functor);
}
像这样使用:
// Checks TimesTwo() on a 3x3 matrix of ones: every coefficient must be 2.
TEST(Stat, TimesTwo) {
  constexpr double kPrecision = 1e-12;
  Eigen::Matrix<double, 3, 3> input;
  input << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0;
  // Assigning to a plain matrix forces evaluation of the lazy expression.
  Eigen::Matrix<double, 3, 3> result = TimesTwo(input);
  std::cout << result << "\n";
  // The test previously only printed the result and could never fail.
  ASSERT_NEAR((result - 2.0 * input).norm(), 0.0, kPrecision);
}
但是当我尝试对协方差矩阵做同样的事情时,它为 `ExpressionType` 推断出了错误的类型(错误信息见文末)。
// Compile-time helper naming the types used to build a lazy sample covariance
// matrix of the columns of ArgType (one observation per row).
template<class ArgType>
struct covariance_helper {
  // Builds, without evaluating it:
  //   (X - colmean(X))^H * (X - colmean(X)) / (n - 1),  n = mat.rows().
  // The divisor uses the input's real scalar type instead of a hard-coded
  // double. For double inputs the expression type is unchanged; for float
  // inputs it no longer mixes float and double scalars.
  // NOTE(review): with a single row, n - 1 == 0 and the result is inf/NaN.
  // NOTE(review): Eigen nests expression operands by value, so the returned
  // expression presumably copies `centered` (no dangling local) while still
  // referring to the caller's matrix.
  template<typename Derived>
  static auto Covariance(Eigen::MatrixBase<Derived> const& mat) {
    using Real = typename Derived::RealScalar;
    auto centered = mat.rowwise() - mat.colwise().mean();
    return (centered.adjoint() * centered) / static_cast<Real>(mat.rows() - 1);
  }

  // Square (Cols x Cols) plain matrix type that holds the covariance.
  using ResultType = Eigen::Matrix<typename ArgType::Scalar,
                                   ArgType::ColsAtCompileTime,
                                   ArgType::ColsAtCompileTime,
                                   ArgType::Options,
                                   ArgType::MaxColsAtCompileTime,
                                   ArgType::MaxColsAtCompileTime>;

  // Expression type Covariance() returns when called with a `const ArgType&`,
  // which is exactly how covariance_functor calls it.
  using ExpressionType =
      decltype(covariance_helper::Covariance(std::declval<ArgType const&>()));
};
template<class ArgType>
class covariance_functor {
using ResultType = typename times_two_helper<ArgType>::ResultType;
using ExpressionType = typename times_two_helper<ArgType>::ExpressionType;
const ArgType &mat_;
ExpressionType expression_;
public:
covariance_functor(const ArgType& arg)
: mat_{arg}
, expression_(covariance_helper<ArgType>::Covariance(arg))
{}
typename ArgType::Scalar operator() (Eigen::Index row, Eigen::Index col) const {
return expression_(row, col);
}
};
// Returns a lazy Cols x Cols covariance expression of `arg` backed by
// covariance_functor. Coefficients are computed when the result is read.
template <class ArgType>
Eigen::CwiseNullaryOp<covariance_functor<ArgType>, typename covariance_helper<ArgType>::ResultType>
Covariance(Eigen::MatrixBase<ArgType> const& arg) {
  using Plain = typename covariance_helper<ArgType>::ResultType;
  const Eigen::Index n = arg.cols();
  const covariance_functor<ArgType> functor(arg.derived());
  return Plain::NullaryExpr(n, n, functor);
}
像这样调用它:
// Checks Covariance() against the NIST handbook example
// (5 observations x 3 variables).
TEST(Stat, Covariance) {
  constexpr double kPrecision = 1e-12;
  // source: https://www.itl.nist.gov/div898/handbook/pmc/section5/pmc541.htm
  Eigen::Matrix<double, 5, 3> mat;
  mat << 4.0, 2.0, 0.6, 4.2, 2.1, 0.59, 3.9, 2.0, 0.58, 4.3, 2.1, 0.62, 4.1, 2.2, 0.63;
  Eigen::Matrix<double, 3, 3> cov_expected;
  cov_expected << 0.025, 0.0075, 0.00175, 0.0075, 0.0070, 0.00135, 0.00175, 0.00135, 0.00043;
  // Previously the result was discarded and kPrecision / cov_expected were
  // unused, so the test checked nothing. Assigning to a matrix evaluates the
  // lazy expression.
  const Eigen::Matrix<double, 3, 3> cov = Covariance(mat);
  ASSERT_NEAR((cov - cov_expected).norm(), 0.0, kPrecision);
}
编译时得到以下错误信息:
In file included from /test/test_stat.cpp:8:0:
/lib/core/stat.hpp: In instantiation of ‘covariance_functor<ArgType>::covariance_functor(const ArgType&) [with ArgType = Eigen::Matrix<double, 5, 3>]’:
/lib/core/stat.hpp:138:60: required from ‘Eigen::CwiseNullaryOp<covariance_functor<ArgType>, typename covariance_helper<ArgType>::ResultType> Covariance(const Eigen::MatrixBase<Derived>&) [with ArgType = Eigen::Matrix<double, 5, 3>; typename covariance_helper<ArgType>::ResultType = Eigen::Matrix<double, 3, 3>]’
/test/test_stat.cpp:19:29: required from here
/lib/core/stat.hpp:127:66: error: no matching function for call to ‘Eigen::CwiseBinaryOp<Eigen::internal::scalar_sum_op<double, double>, const Eigen::Matrix<double, 5, 3>, const Eigen::Matrix<double, 5, 3> >::CwiseBinaryOp(Eigen::CwiseBinaryOp<Eigen::internal::scalar_quotient_op<double, double>, const Eigen::Product<Eigen::Transpose<const Eigen::CwiseBinaryOp<Eigen::internal::scalar_difference_op<double, double>, const Eigen::Matrix<double, 5, 3>, const Eigen::Replicate<Eigen::PartialReduxExpr<const Eigen::Matrix<double, 5, 3>, Eigen::internal::member_mean<double>, 0>, 5, 1> > >, Eigen::CwiseBinaryOp<Eigen::internal::scalar_difference_op<double, double>, const Eigen::Matrix<double, 5, 3>, const Eigen::Replicate<Eigen::PartialReduxExpr<const Eigen::Matrix<double, 5, 3>, Eigen::internal::member_mean<double>, 0>, 5, 1> >, 0>, const Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<double>, const Eigen::Matrix<double, 3, 3> > >)’
, expression_(covariance_helper<ArgType>::Covariance(arg))
我遗漏了什么?我想做的事情可行吗?对于(特别是协方差矩阵的)惰性求值,有这样的需求是否合理?
这对我来说是一个练习(不是家庭作业),所以我真的很想知道是否有可能以这种方式推导出表达式模板的完整类型。
我当然也对这种方法的实用性感兴趣,但程度较小。
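编辑:
由于这可能对其他人有所帮助,下面给出一个可以正常工作的版本。关键的修正是:`covariance_functor` 中的 `ResultType` / `ExpressionType` 原先误取自 `times_two_helper`,应改为取自 `covariance_helper`。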
// Compile-time helper naming the types used to build a lazy sample covariance
// matrix of the columns of ArgType (one observation per row).
template<class ArgType>
struct covariance_helper {
  // Builds, without evaluating it:
  //   (X - colmean(X))^H * (X - colmean(X)) / (n - 1),  n = mat.rows().
  // The divisor uses the input's real scalar type instead of a hard-coded
  // double. For double inputs the expression type is unchanged; for float
  // inputs it no longer mixes float and double scalars.
  // NOTE(review): with a single row, n - 1 == 0 and the result is inf/NaN.
  // NOTE(review): Eigen nests expression operands by value, so the returned
  // expression presumably copies `centered` (no dangling local) while still
  // referring to the caller's matrix.
  template<typename Derived>
  static auto Covariance(Eigen::MatrixBase<Derived> const& mat) {
    using Real = typename Derived::RealScalar;
    auto centered = mat.rowwise() - mat.colwise().mean();
    return (centered.adjoint() * centered) / static_cast<Real>(mat.rows() - 1);
  }

  // Square (Cols x Cols) plain matrix type that holds the covariance.
  using ResultType = Eigen::Matrix<typename ArgType::Scalar,
                                   ArgType::ColsAtCompileTime,
                                   ArgType::ColsAtCompileTime,
                                   ArgType::Options,
                                   ArgType::MaxColsAtCompileTime,
                                   ArgType::MaxColsAtCompileTime>;

  // Expression type Covariance() returns for an argument of type ArgType.
  using ExpressionType = decltype(covariance_helper::Covariance(
      std::declval<Eigen::MatrixBase<ArgType> const&>()));
};
// Nullary functor for Eigen's NullaryExpr that serves coefficients of the
// lazy covariance expression built by covariance_helper.
// The unused `mat_` reference member was removed: it was only initialised,
// never read, and a reference member does not keep the input alive.
template<class ArgType>
class covariance_functor {
  using ResultType = typename covariance_helper<ArgType>::ResultType;
  using ExpressionType = typename covariance_helper<ArgType>::ExpressionType;

  // Unevaluated covariance expression.
  // NOTE(review): it presumably still refers to the input matrix, so that
  // matrix must outlive this functor.
  ExpressionType expression_;

public:
  covariance_functor(const ArgType& arg)
      : expression_(covariance_helper<ArgType>::Covariance(arg)) {}

  // Coefficient (row, col) of the covariance matrix.
  // NOTE(review): Eigen creates an evaluator for the whole expression on each
  // coefficient access. With the default (non-lazy) product this likely
  // re-evaluates the full product per call, so profile before using it on
  // large inputs.
  typename ArgType::Scalar operator()(Eigen::Index row, Eigen::Index col) const {
    return expression_(row, col);
  }
};
// Returns a lazy Cols x Cols covariance expression of `arg` backed by
// covariance_functor. Nothing is computed until the result is read.
template <class ArgType>
Eigen::CwiseNullaryOp<covariance_functor<ArgType>, typename covariance_helper<ArgType>::ResultType>
Covariance(Eigen::MatrixBase<ArgType> const& arg) {
  using Functor = covariance_functor<ArgType>;
  using Plain = typename covariance_helper<ArgType>::ResultType;
  const auto& input = arg.derived();
  return Plain::NullaryExpr(input.cols(), input.cols(), Functor(input));
}
像这样调用它:
// Compares the lazily evaluated Covariance() with the NIST handbook example.
TEST(Stat, Covariance) {
  constexpr double kPrecision = 1e-12;
  // source: https://www.itl.nist.gov/div898/handbook/pmc/section5/pmc541.htm
  // 5 observations (rows) of 3 variables (columns).
  Eigen::Matrix<double, 5, 3> mat;
  mat << 4.0, 2.0, 0.6,
         4.2, 2.1, 0.59,
         3.9, 2.0, 0.58,
         4.3, 2.1, 0.62,
         4.1, 2.2, 0.63;
  Eigen::Matrix<double, 3, 3> cov_expected;
  cov_expected << 0.025,   0.0075,  0.00175,
                  0.0075,  0.0070,  0.00135,
                  0.00175, 0.00135, 0.00043;
  // Assigning to a dynamic-size matrix forces evaluation of the expression.
  Eigen::MatrixXd cov = Covariance(mat);
  ASSERT_NEAR((cov - cov_expected).norm(), 0.0, kPrecision);
}
或者更简单的写法:
// Self-contained nullary functor for the sample covariance of the columns of
// a matrix (one observation per row). It stores the unevaluated expression
// and serves coefficients from it on demand.
template<typename Derived>
struct Covariance {
  using MBase = Eigen::MatrixBase<Derived> const&;

  // Builds (X - colmean(X))^H * (X - colmean(X)) / (n - 1) without
  // evaluating it, where n = mat.rows().
  // - No longer `constexpr`: Eigen expression types are not literal types, so
  //   the specifier could never be honoured. In C++17 this is likely
  //   ill-formed, no diagnostic required.
  // - The divisor uses Derived::RealScalar rather than a hard-coded double, so
  //   float inputs do not mix scalar types. For double the type is unchanged.
  static auto compute(MBase mat) {
    using Real = typename Derived::RealScalar;
    auto centered = mat.rowwise() - mat.colwise().mean();
    return (centered.adjoint() * centered) / static_cast<Real>(mat.rows() - 1);
  }

  using ResultExpr = decltype(compute(std::declval<MBase>()));

  Covariance(MBase mat)
      : result_expr{compute(mat)} {}

  // Coefficient (row, col) of the covariance matrix.
  typename Derived::Scalar operator()(Eigen::Index row, Eigen::Index col) const {
    return result_expr(row, col);
  }

  // NOTE(review): presumably still refers to the matrix passed to the
  // constructor; that matrix must outlive this object and its copies.
  ResultExpr result_expr;
};
// Builds a NullaryExpr directly from the Covariance functor and compares it
// with the NIST handbook example.
TEST(Stat, EigenUnaryExpr) {
  constexpr double kPrecision = 1e-12;
  // source: https://www.itl.nist.gov/div898/handbook/pmc/section5/pmc541.htm
  Eigen::Matrix<double, 5, 3> mat;
  mat << 4.0, 2.0, 0.6,
         4.2, 2.1, 0.59,
         3.9, 2.0, 0.58,
         4.3, 2.1, 0.62,
         4.1, 2.2, 0.63;
  Eigen::Matrix<double, 3, 3> cov_expected;
  cov_expected << 0.025,   0.0075,  0.00175,
                  0.0075,  0.0070,  0.00135,
                  0.00175, 0.00135, 0.00043;
  Covariance cov{mat};  // CTAD: Covariance<Eigen::Matrix<double, 5, 3>>
  using ResultExpr = decltype(cov)::ResultExpr;
  const Eigen::Index rows = cov.result_expr.rows();
  const Eigen::Index cols = cov.result_expr.cols();
  // The expression type inherits the static NullaryExpr factory, which here
  // builds a nullary expression driven by a copy of `cov`.
  auto cov_expr = ResultExpr::NullaryExpr(rows, cols, cov);
  std::cerr << cov_expr << "\n";
  ASSERT_NEAR((cov_expr - cov_expected).norm(), 0.0, kPrecision);
}