Skip to content

Commit 8418021

Browse files
committed
introduced is_contiguous(Range) + use range_traits instead of directly reading Range::order
1 parent cefac13 commit 8418021

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

btas/optimize/contract.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ template<typename _T, class _TensorA, class _TensorB, class _TensorC,
1717
void contract_211(const _T& alpha, const _TensorA& A, const btas::DEFAULT::index<_UA>& aA, const _TensorB& B, const btas::DEFAULT::index<_UB>& aB,
1818
const _T& beta, _TensorC& C, const btas::DEFAULT::index<_UC>& aC, const bool conjgA, const bool conjgB) {
1919
assert(aA.size() == 2 && aB.size() == 1 && aC.size() == 1);
20-
assert(A.range().ordinal().contiguous() && B.range().ordinal().contiguous() && C.range().ordinal().contiguous());
20+
assert(is_contiguous(A.range()) && is_contiguous(B.range()) && is_contiguous(C.range()));
2121
if (conjgB) throw std::logic_error("complex conjugation of 1-index tensors is not considered in contract_211");
2222

2323
const bool notrans = aB[0] == aA[1];
@@ -36,7 +36,7 @@ void contract_222(const _T& alpha, const _TensorA& A, const btas::DEFAULT::index
3636
const _T& beta, _TensorC& C, const btas::DEFAULT::index<_UC>& aC, const bool conjgA, const bool conjgB) {
3737
// TODO we do not consider complex matrices yet.
3838
assert(aA.size() == 2 && aB.size() == 2 && aC.size() == 2);
39-
assert(A.range().ordinal().contiguous() && B.range().ordinal().contiguous() && C.range().ordinal().contiguous());
39+
assert(is_contiguous(A.range()) && is_contiguous(B.range()) && is_contiguous(C.range()));
4040
if (std::find(aA.begin(), aA.end(), aC.front()) != aA.end()) {
4141
// then multiply A * B -> C
4242
const bool notransA = aA.front() == aC.front();
@@ -65,10 +65,10 @@ template<typename _T, class _TensorA, class _TensorB, class _TensorC,
6565
void contract_323(const _T& alpha, const _TensorA& A, const btas::DEFAULT::index<_UA>& aA, const _TensorB& B, const btas::DEFAULT::index<_UB>& aB,
6666
const _T& beta, _TensorC& C, const btas::DEFAULT::index<_UC>& aC, const bool conjgA, const bool conjgB) {
6767
assert(aA.size() == 3 && aB.size() == 2 && aC.size() == 3);
68-
assert(A.range().ordinal().contiguous() && B.range().ordinal().contiguous() && C.range().ordinal().contiguous());
68+
assert(is_contiguous(A.range()) && is_contiguous(B.range()) && is_contiguous(C.range()));
6969
if (conjgA) throw std::logic_error("complex conjugation of 3-index tensors is not considered in contract_323");
7070

71-
// TODO this function is limited to special cases where one of three indices of A will be replaced in C. Permuation is not considered so far.
71+
// TODO this function is limited to special cases where one of three indices of A will be replaced in C. Permutation is not considered so far.
7272
// first idenfity which indices to be rotated
7373
int irot = -1;
7474
for (int i = 0; i != 3; ++i)
@@ -128,7 +128,7 @@ template<typename _T, class _TensorA, class _TensorB, class _TensorC,
128128
void contract_332(const _T& alpha, const _TensorA& A, const btas::DEFAULT::index<_UA>& aA, const _TensorB& B, const btas::DEFAULT::index<_UB>& aB,
129129
const _T& beta, _TensorC& C, const btas::DEFAULT::index<_UC>& aC, const bool conjgA, const bool conjgB) {
130130
assert(aA.size() == 3 && aB.size() == 3 && aC.size() == 2);
131-
assert(A.range().ordinal().contiguous() && B.range().ordinal().contiguous() && C.range().ordinal().contiguous());
131+
assert(is_contiguous(A.range()) && is_contiguous(B.range()) && is_contiguous(C.range()));
132132

133133
const bool back2 = aA[0] == aB[0] && aA[1] == aB[1];
134134
const bool front2 = aA[1] == aB[1] && aA[2] == aB[2];

btas/range.h

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,12 +1181,22 @@ namespace btas {
11811181
std::rbegin(r2_extent));
11821182
}
11831183

1184+
/// Tests whether a range is contiguous, i.e. whether its ordinal values form a contiguous range
1185+
1186+
/// \param range a Range
1187+
/// \return true if \p range is contiguous
1188+
template <CBLAS_ORDER _Order,
1189+
typename _Index,
1190+
typename _Ordinal>
1191+
inline bool is_contiguous(const RangeNd<_Order, _Index, _Ordinal>& range) {
1192+
return range.ordinal().contiguous();
1193+
}
1194+
11841195
/// Permutes a Range
11851196

1186-
/// permutes the axes using permutation \c p={p[0],p[1],...} specified in the preimage ("from") convention;
1187-
/// for example, after this call \c lobound()[p[i]] will return the value originally
1188-
/// returned by \c lobound()[i]
1189-
/// \param perm a sequence specifying from-permutation of the axes
1197+
/// permutes the dimensions using permutation \c p = {p[0], p[1], ... }; for example, if \c lobound() initially returned
1198+
/// {lb[0], lb[1], ... }, after this call \c lobound() will return {lb[p[0]], lb[p[1]], ...}.
1199+
/// \param perm an array specifying permutation of the dimensions
11901200
template <CBLAS_ORDER _Order,
11911201
typename _Index,
11921202
typename _Ordinal,
@@ -1216,10 +1226,9 @@ namespace btas {
12161226

12171227
/// Permutes a Range
12181228

1219-
/// permutes the axes using permutation \c p={p[0],p[1],...} specified in the preimage ("from") convention;
1220-
/// for example, after this call \c lobound()[p[i]] will return the value originally
1221-
/// returned by \c lobound()[i]
1222-
/// \param perm an initializer list specifying from-permutation of the axes
1229+
/// permutes the axes using permutation \c p = {p[0], p[1], ... }; for example, if \c lobound() initially returned
1230+
/// {lb[0], lb[1], ... }, after this call \c lobound() will return {lb[p[0]], lb[p[1]], ...} .
1231+
/// \param perm an array specifying permutation of the axes
12231232
template <CBLAS_ORDER _Order,
12241233
typename _Index,
12251234
typename _Ordinal,

btas/tensor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,9 @@ namespace btas {
736736
bool operator==(const _Tensor1& t1, const _Tensor2& t2) {
737737
using std::cbegin;
738738
using std::cend;
739-
if (t1.range().order == t2.range().order &&
740-
t1.range().ordinal().contiguous() &&
741-
t2.range().ordinal().contiguous()) // plain Tensor
739+
if (btas::range_traits<std::decay_t<decltype(t1.range())>>::order == btas::range_traits<std::decay_t<decltype(t2.range())>>::order &&
740+
is_contiguous(t1.range()) &&
741+
is_contiguous(t2.range())) // plain Tensor
742742
return congruent(t1.range(), t2.range()) && std::equal(cbegin(t1.storage()),
743743
cend(t1.storage()),
744744
cbegin(t2.storage()));

0 commit comments

Comments
 (0)