Skip to content

Commit 67a474d

Browse files
authored
fixes assignment bug (#197)
1 parent b15d062 commit 67a474d

File tree

8 files changed

+79
-20
lines changed

8 files changed

+79
-20
lines changed

include/tensorwrapper/detail_/dsl_base.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,11 @@ class DSLBase {
274274
/// Checks that @p output is a subset of @p input
275275
void assert_is_subset_(const label_type& output,
276276
const label_type& input) const {
277-
if(output.intersection(input).size() < output.unique_index_size())
277+
// Subset would have equality
278+
if(output.intersection(input).size() < output.unique_index_size()) {
278279
throw std::runtime_error(
279280
"Output indices must be a subset of input indices");
281+
}
280282
}
281283

282284
/// Asserts that @p lhs is a permutation of @p rhs

include/tensorwrapper/dsl/pairwise_parser.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ class PairwiseParser {
5959
*/
6060
template<typename LHSType, typename RHSType>
6161
void dispatch(LHSType&& lhs, const RHSType& rhs) {
62-
lhs.object().permute_assignment(lhs.labels(), rhs);
62+
if(lhs.labels().is_permutation(rhs.labels()))
63+
lhs.object().permute_assignment(lhs.labels(), rhs);
64+
else { // User just wants us to assign RHS to LHS
65+
lhs.labels() = rhs.labels();
66+
lhs.object().permute_assignment(rhs.labels(), rhs);
67+
}
6368
}
6469

6570
/** @brief Handles adding two expressions together.

tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
7979
auto scalar_layout = scalar_physical();
8080
auto vector_layout = vector_physical(2);
8181
auto matrix_layout = matrix_physical(2, 3);
82-
auto tensor_layout = tensor_physical(1, 2, 3);
82+
auto tensor_layout = tensor3_physical(1, 2, 3);
8383

8484
scalar_buffer scalar(eigen_scalar, scalar_layout);
8585
vector_buffer vector(eigen_vector, vector_layout);
@@ -276,7 +276,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
276276
auto tensor2 = testing::eigen_tensor3<TestType>();
277277

278278
std::array<int, 3> p102{1, 0, 2};
279-
auto l102 = testing::tensor_physical(2, 1, 3);
279+
auto l102 = testing::tensor3_physical(2, 1, 3);
280280
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);
281281

282282
auto tijk = tensor("i,j,k");
@@ -285,7 +285,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
285285
tensor2.addition_assignment("k,j,i", tijk, tjik);
286286

287287
std::array<int, 3> p210{2, 1, 0};
288-
auto l210 = testing::tensor_physical(3, 2, 1);
288+
auto l210 = testing::tensor3_physical(3, 2, 1);
289289
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
290290
corr.value()(0, 0, 0) = 20.0;
291291
corr.value()(0, 1, 0) = 80.0;
@@ -392,7 +392,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
392392
auto tensor2 = testing::eigen_tensor3<TestType>();
393393

394394
std::array<int, 3> p102{1, 0, 2};
395-
auto l102 = testing::tensor_physical(2, 1, 3);
395+
auto l102 = testing::tensor3_physical(2, 1, 3);
396396
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);
397397

398398
auto tijk = tensor("i,j,k");
@@ -401,7 +401,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
401401
tensor2.subtraction_assignment("k,j,i", tijk, tjik);
402402

403403
std::array<int, 3> p210{2, 1, 0};
404-
auto l210 = testing::tensor_physical(3, 2, 1);
404+
auto l210 = testing::tensor3_physical(3, 2, 1);
405405
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
406406
corr.value()(0, 0, 0) = 0.0;
407407
corr.value()(0, 1, 0) = 0.0;
@@ -631,7 +631,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
631631
auto tensor2 = testing::eigen_tensor3<TestType>();
632632

633633
std::array<int, 3> p102{1, 0, 2};
634-
auto l102 = testing::tensor_physical(2, 1, 3);
634+
auto l102 = testing::tensor3_physical(2, 1, 3);
635635
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);
636636

637637
auto tijk = tensor("i,j,k");
@@ -640,7 +640,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
640640
tensor2.multiplication_assignment("k,j,i", tijk, tjik);
641641

642642
std::array<int, 3> p210{2, 1, 0};
643-
auto l210 = testing::tensor_physical(3, 2, 1);
643+
auto l210 = testing::tensor3_physical(3, 2, 1);
644644
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
645645
corr.value()(0, 0, 0) = 100.0;
646646
corr.value()(0, 1, 0) = 1600.0;

tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ using namespace tensorwrapper;
2222
TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) {
2323
using object_type = TestType;
2424

25-
auto scalar_values = testing::scalar_values();
26-
auto vector_values = testing::vector_values();
27-
auto matrix_values = testing::matrix_values();
25+
auto scalar_values = testing::scalar_values();
26+
auto vector_values = testing::vector_values();
27+
auto matrix_values = testing::matrix_values();
28+
auto tensor4_values = testing::tensor4_values();
2829

2930
auto value0 = std::get<object_type>(scalar_values);
3031
auto value1 = std::get<object_type>(vector_values);
3132
auto value2 = std::get<object_type>(matrix_values);
33+
auto value4 = std::get<object_type>(tensor4_values);
3234

3335
SECTION("assignment") {
3436
value0("i,j") = value2("i,j");
@@ -61,6 +63,11 @@ TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) {
6163

6264
value1.multiplication_assignment("i,j", value2("i,j"), value2("i,j"));
6365
REQUIRE(value1.are_equal(value0));
66+
67+
value0("m,n") = value2("l,s") * value4("m,n,s,l");
68+
value1.multiplication_assignment("m,n", value2("l,s"),
69+
value4("m,n,s,l"));
70+
REQUIRE(value1.are_equal(value0));
6471
}
6572

6673
SECTION("scalar_multiplication") {

tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ TEST_CASE("DummyIndices<std::string>") {
280280
REQUIRE(matrix.concatenation(vector) == dummy_indices_type("i,j,i"));
281281
REQUIRE(matrix.concatenation(matrix) == dummy_indices_type("i,j,i,j"));
282282
REQUIRE(matrix.concatenation(matrix2) == dummy_indices_type("i,j,k,l"));
283+
284+
auto x = matrix.concatenation(dummy_indices_type("i,j,l,s"));
285+
REQUIRE(x == dummy_indices_type("i,j,i,j,l,s"));
283286
}
284287

285288
SECTION("intersection") {
@@ -298,6 +301,12 @@ TEST_CASE("DummyIndices<std::string>") {
298301
REQUIRE(matrix.intersection(vector) == dummy_indices_type("i"));
299302
REQUIRE(matrix.intersection(matrix) == dummy_indices_type("i,j"));
300303
REQUIRE(matrix.intersection(matrix2) == dummy_indices_type(""));
304+
305+
auto x = matrix.intersection(dummy_indices_type("i,j,l,s"));
306+
REQUIRE(x == dummy_indices_type("i,j"));
307+
308+
auto y = matrix.intersection(dummy_indices_type("i,j,i,j,l,s"));
309+
REQUIRE(x == dummy_indices_type("i,j"));
301310
}
302311

303312
SECTION("difference") {

tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,25 @@ inline auto matrix_values() {
5454
Tensor{{1.0, 2.0}, {3.0, 4.0}}};
5555
}
5656

57+
inline auto tensor3_values() {
58+
return dsl_types{
59+
smooth_tensor3(),
60+
tensorwrapper::symmetry::Group(3),
61+
tensorwrapper::sparsity::Pattern(3),
62+
tensor3_logical(),
63+
tensor3_physical(),
64+
Tensor{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}};
65+
}
66+
67+
inline auto tensor4_values() {
68+
return dsl_types{
69+
smooth_tensor4(),
70+
tensorwrapper::symmetry::Group(4),
71+
tensorwrapper::sparsity::Pattern(4),
72+
tensor4_logical(),
73+
tensor4_physical(),
74+
Tensor{{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
75+
{{{9.0, 10.0}, {11.0, 12.0}}, {{13.0, 14.0}, {15.0, 16.0}}}}};
76+
}
77+
5778
} // namespace tensorwrapper::testing

tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ inline auto matrix_logical(std::size_t i = 10, std::size_t j = 10) {
3636
return tensorwrapper::layout::Logical(smooth_matrix(i, j));
3737
}
3838

39-
inline auto tensor_logical(std::size_t i = 10, std::size_t j = 10,
40-
std::size_t k = 10) {
41-
return tensorwrapper::layout::Logical(smooth_tensor(i, j, k));
39+
inline auto tensor3_logical(std::size_t i = 10, std::size_t j = 10,
40+
std::size_t k = 10) {
41+
return tensorwrapper::layout::Logical(smooth_tensor3(i, j, k));
42+
}
43+
44+
inline auto tensor4_logical(std::size_t i = 10, std::size_t j = 10,
45+
std::size_t k = 10, std::size_t l = 10) {
46+
return tensorwrapper::layout::Logical(smooth_tensor4(i, j, k, l));
4247
}
4348

4449
// -----------------------------------------------------------------------------
@@ -57,9 +62,14 @@ inline auto matrix_physical(std::size_t i = 10, std::size_t j = 10) {
5762
return tensorwrapper::layout::Physical(smooth_matrix(i, j));
5863
}
5964

60-
inline auto tensor_physical(std::size_t i = 10, std::size_t j = 10,
61-
std::size_t k = 10) {
62-
return tensorwrapper::layout::Physical(smooth_tensor(i, j, k));
65+
inline auto tensor3_physical(std::size_t i = 10, std::size_t j = 10,
66+
std::size_t k = 10) {
67+
return tensorwrapper::layout::Physical(smooth_tensor3(i, j, k));
68+
}
69+
70+
inline auto tensor4_physical(std::size_t i = 10, std::size_t j = 10,
71+
std::size_t k = 10, std::size_t l = 10) {
72+
return tensorwrapper::layout::Physical(smooth_tensor4(i, j, k, l));
6373
}
6474

6575
} // namespace tensorwrapper::testing

tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,14 @@ inline auto smooth_matrix(std::size_t i = 10, std::size_t j = 10) {
3434
return tensorwrapper::shape::Smooth{i, j};
3535
}
3636

37-
inline auto smooth_tensor(std::size_t i = 10, std::size_t j = 10,
38-
std::size_t k = 10) {
37+
inline auto smooth_tensor3(std::size_t i = 10, std::size_t j = 10,
38+
std::size_t k = 10) {
3939
return tensorwrapper::shape::Smooth{i, j, k};
4040
}
4141

42+
inline auto smooth_tensor4(std::size_t i = 10, std::size_t j = 10,
43+
std::size_t k = 10, std::size_t l = 10) {
44+
return tensorwrapper::shape::Smooth{i, j, k, l};
45+
}
46+
4247
} // namespace tensorwrapper::testing

0 commit comments

Comments
 (0)