From 35f7829af10c61e33dd2e2a7a015058e11a11ea0 Mon Sep 17 00:00:00 2001 From: Stanislaw Halik Date: Sat, 25 Mar 2017 14:17:07 +0100 Subject: update --- .../test/cxx11_tensor_concatenation.cpp | 137 +++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 eigen/unsupported/test/cxx11_tensor_concatenation.cpp (limited to 'eigen/unsupported/test/cxx11_tensor_concatenation.cpp') diff --git a/eigen/unsupported/test/cxx11_tensor_concatenation.cpp b/eigen/unsupported/test/cxx11_tensor_concatenation.cpp new file mode 100644 index 0000000..03ef12e --- /dev/null +++ b/eigen/unsupported/test/cxx11_tensor_concatenation.cpp @@ -0,0 +1,137 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + +using Eigen::Tensor; + +template +static void test_dimension_failures() +{ + Tensor left(2, 3, 1); + Tensor right(3, 3, 1); + left.setRandom(); + right.setRandom(); + + // Okay; other dimensions are equal. + Tensor concatenation = left.concatenate(right, 0); + + // Dimension mismatches. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); + + // Axis > NumDims or < 0. + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); + VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); +} + +template +static void test_static_dimension_failure() +{ + Tensor left(2, 3); + Tensor right(2, 3, 1); + +#ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE + // Technically compatible, but we static assert that the inputs have same + // NumDims. + Tensor concatenation = left.concatenate(right, 0); +#endif + + // This can be worked around in this case. + Tensor concatenation = left + .reshape(Tensor::Dimensions(2, 3, 1)) + .concatenate(right, 0); + Tensor alternative = left + .concatenate(right.reshape(Tensor::Dimensions{{{2, 3}}}), 0); +} + +template +static void test_simple_concatenation() +{ + Tensor left(2, 3, 1); + Tensor right(2, 3, 1); + left.setRandom(); + right.setRandom(); + + Tensor concatenation = left.concatenate(right, 0); + VERIFY_IS_EQUAL(concatenation.dimension(0), 4); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int j = 0; j < 3; ++j) { + for (int i = 0; i < 2; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int i = 2; i < 4; ++i) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); + } + } + + concatenation = left.concatenate(right, 1); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 6); + VERIFY_IS_EQUAL(concatenation.dimension(2), 1); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + } + for (int j = 3; j < 6; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); + } + } + + concatenation = left.concatenate(right, 2); + VERIFY_IS_EQUAL(concatenation.dimension(0), 2); + VERIFY_IS_EQUAL(concatenation.dimension(1), 3); + VERIFY_IS_EQUAL(concatenation.dimension(2), 2); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); + VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); + } + } +} + + +// TODO(phli): Add test once we have a real vectorized implementation. +// static void test_vectorized_concatenation() {} + +static void test_concatenation_as_lvalue() +{ + Tensor t1(2, 3); + Tensor t2(2, 3); + t1.setRandom(); + t2.setRandom(); + + Tensor result(4, 3); + result.setRandom(); + t1.concatenate(t2, 0) = result; + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + VERIFY_IS_EQUAL(t1(i, j), result(i, j)); + VERIFY_IS_EQUAL(t2(i, j), result(i+2, j)); + } + } +} + + +void test_cxx11_tensor_concatenation() +{ + CALL_SUBTEST(test_dimension_failures()); + CALL_SUBTEST(test_dimension_failures()); + CALL_SUBTEST(test_static_dimension_failure()); + CALL_SUBTEST(test_static_dimension_failure()); + CALL_SUBTEST(test_simple_concatenation()); + CALL_SUBTEST(test_simple_concatenation()); + // CALL_SUBTEST(test_vectorized_concatenation()); + CALL_SUBTEST(test_concatenation_as_lvalue()); + +} -- cgit v1.2.3