Skip to content

Commit 06d774e

Browse files
committed
eckit::linalg::sparse::LinearAlgebraTorch backend
1 parent 3d7f4fe commit 06d774e

File tree

5 files changed

+182
-0
lines changed

5 files changed

+182
-0
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,13 @@ set_package_properties( ViennaCL PROPERTIES
275275
TYPE RECOMMENDED
276276
PURPOSE "Dense and sparse matrix operations on OpenCL devices" )
277277

278+
### Torch
279+
280+
ecbuild_add_option( FEATURE TORCH
281+
DEFAULT OFF
282+
DESCRIPTION "Torch linear algebra operations"
283+
REQUIRED_PACKAGES Torch )
284+
278285
### LibRsync
279286

280287
ecbuild_add_option( FEATURE RSYNC

src/eckit/linalg/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ if( eckit_HAVE_MKL )
7575
list( APPEND eckit_la_plibs "${MKL_LIBRARIES}" )
7676
endif()
7777

78+
if( eckit_HAVE_TORCH )
79+
list( APPEND eckit_la_srcs
80+
sparse/LinearAlgebraTorch.cc
81+
sparse/LinearAlgebraTorch.h )
82+
list( APPEND eckit_la_pincludes torch )
83+
list( APPEND eckit_la_plibs torch )
84+
endif()
85+
7886
if( eckit_HAVE_VIENNACL )
7987
list( APPEND eckit_la_srcs
8088
dense/LinearAlgebraViennaCL.cc
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* (C) Copyright 1996- ECMWF.
3+
*
4+
* This software is licensed under the terms of the Apache Licence Version 2.0
5+
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6+
* In applying this licence, ECMWF does not waive the privileges and immunities
7+
* granted to it by virtue of its status as an intergovernmental organisation
8+
* nor does it submit to any jurisdiction.
9+
*/
10+
11+
12+
#include "eckit/linalg/sparse/LinearAlgebraTorch.h"
13+
14+
#include <cstring>
15+
#include <ostream>
16+
17+
#include "torch/torch.h"
18+
19+
#include "eckit/config/Resource.h"
20+
#include "eckit/exception/Exceptions.h"
21+
#include "eckit/linalg/Matrix.h"
22+
#include "eckit/linalg/SparseMatrix.h"
23+
#include "eckit/linalg/Vector.h"
24+
#include "eckit/linalg/sparse/LinearAlgebraGeneric.h"
25+
26+
27+
namespace eckit::linalg::sparse {
28+
29+
30+
static_assert(std::is_same<int32_t, Index>::value, "Index type mismatch");
31+
static_assert(std::is_same<double, Scalar>::value, "Scalar type mismatch");
32+
33+
static const LinearAlgebraTorch __la_torch("torch");
34+
35+
36+
namespace {
37+
38+
39+
torch::TensorOptions make_options(torch::ScalarType _dtype) {
40+
static const auto _device = [](const std::string& dev) {
41+
return dev == "cpu" ? torch::DeviceType::CPU
42+
: dev == "cuda" ? torch::DeviceType::CUDA
43+
: dev == "hip" ? torch::DeviceType::HIP
44+
: dev == "fpga" ? torch::DeviceType::FPGA
45+
: dev == "maia" ? torch::DeviceType::MAIA
46+
: dev == "xla" ? torch::DeviceType::XLA
47+
: dev == "mps" ? torch::DeviceType::MPS
48+
: dev == "meta" ? torch::DeviceType::Meta
49+
: dev == "vulkan" ? torch::DeviceType::Vulkan
50+
: dev == "metal" ? torch::DeviceType::Metal
51+
: dev == "xpu" ? torch::DeviceType::XPU
52+
: dev == "hpu" ? torch::DeviceType::HPU
53+
: dev == "ve" ? torch::DeviceType::VE
54+
: dev == "lazy" ? torch::DeviceType::Lazy
55+
: dev == "ipu" ? torch::DeviceType::IPU
56+
: dev == "mtia" ? torch::DeviceType::MTIA
57+
: NOTIMP;
58+
}(eckit::Resource<std::string>("$ECKIT_LINALG_TORCH_DEVICE;eckitLinalgTorchDevice;-eckitLinalgTorchDevice", "cpu"));
59+
60+
return torch::TensorOptions().dtype(_dtype).device(_device);
61+
}
62+
63+
64+
} // namespace
65+
66+
67+
void LinearAlgebraTorch::spmv(const SparseMatrix& A, const Vector& x, Vector& y) const {
68+
const auto options_int = make_options(torch::kInt32);
69+
const auto options_float = make_options(torch::kFloat64);
70+
71+
const auto Ni = static_cast<int32_t>(A.rows());
72+
const auto Nj = static_cast<int32_t>(A.cols());
73+
const auto Nz = static_cast<int32_t>(A.nonZeros());
74+
ASSERT(Ni == y.rows());
75+
ASSERT(Nj == x.rows());
76+
77+
// torch tensors
78+
const auto ia = torch::from_blob(const_cast<int32_t*>(A.outer()), Ni + 1, options_int);
79+
const auto ja = torch::from_blob(const_cast<int32_t*>(A.inner()), Nz, options_int);
80+
const auto a = torch::from_blob(const_cast<double*>(A.data()), Nz, options_float);
81+
82+
const auto A_tensor = torch::sparse_csr_tensor(ia, ja, a, {Ni, Nj}, options_float.layout(torch::kSparseCsr));
83+
84+
// multiplication
85+
const auto x_tensor = torch::from_blob(const_cast<double*>(x.data()), Nj, options_float);
86+
const auto y_tensor = torch::matmul(A_tensor, x_tensor);
87+
88+
// assignment
89+
std::memcpy(y.data(), y_tensor.data_ptr<double>(), Ni * sizeof(double));
90+
}
91+
92+
93+
void LinearAlgebraTorch::spmm(const SparseMatrix& A, const Matrix& X, Matrix& Y) const {
94+
const auto options_int = make_options(torch::kInt32);
95+
const auto options_float = make_options(torch::kFloat64);
96+
97+
const auto Ni = static_cast<int32_t>(A.rows());
98+
const auto Nj = static_cast<int32_t>(A.cols());
99+
const auto Nk = static_cast<int32_t>(X.cols());
100+
const auto Nz = static_cast<int32_t>(A.nonZeros());
101+
ASSERT(Ni == Y.rows());
102+
ASSERT(Nj == X.rows());
103+
ASSERT(Nk == Y.cols());
104+
105+
// torch tensors
106+
const auto ia = torch::from_blob(const_cast<int32_t*>(A.outer()), Ni + 1, options_int);
107+
const auto ja = torch::from_blob(const_cast<int32_t*>(A.inner()), Nz, options_int);
108+
const auto a = torch::from_blob(const_cast<double*>(A.data()), Nz, options_float);
109+
110+
const auto A_tensor = torch::sparse_csr_tensor(ia, ja, a, {Ni, Nj}, options_float.layout(torch::kSparseCsr));
111+
112+
// multiplication and conversion from column-major to row-major (and back)
113+
auto t = [](auto&& tensor) { return tensor.transpose(0, 1).contiguous(); };
114+
115+
const auto X_tensor = t(torch::from_blob(const_cast<double*>(X.data()), {Nk, Nj}, options_float));
116+
const auto Y_tensor = t(torch::matmul(A_tensor, X_tensor));
117+
118+
// assignment
119+
std::memcpy(Y.data(), Y_tensor.data_ptr<double>(), Y.size() * sizeof(double));
120+
}
121+
122+
123+
void LinearAlgebraTorch::dsptd(const Vector& x, const SparseMatrix& A, const Vector& y, SparseMatrix& B) const {
124+
static const sparse::LinearAlgebraGeneric generic;
125+
generic.dsptd(x, A, y, B);
126+
}
127+
128+
129+
void LinearAlgebraTorch::print(std::ostream& out) const {
130+
out << "LinearAlgebraTorch[]";
131+
}
132+
133+
134+
} // namespace eckit::linalg::sparse
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* (C) Copyright 1996- ECMWF.
3+
*
4+
* This software is licensed under the terms of the Apache Licence Version 2.0
5+
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6+
* In applying this licence, ECMWF does not waive the privileges and immunities
7+
* granted to it by virtue of its status as an intergovernmental organisation
8+
* nor does it submit to any jurisdiction.
9+
*/
10+
11+
12+
#pragma once
13+
14+
#include "eckit/linalg/LinearAlgebraSparse.h"
15+
16+
namespace eckit::linalg::sparse {
17+
18+
struct LinearAlgebraTorch final : public LinearAlgebraSparse {
19+
LinearAlgebraTorch() {}
20+
LinearAlgebraTorch(const std::string& name) : LinearAlgebraSparse(name) {}
21+
22+
void spmv(const SparseMatrix&, const Vector&, Vector&) const override;
23+
void spmm(const SparseMatrix&, const Matrix&, Matrix&) const override;
24+
void dsptd(const Vector&, const SparseMatrix&, const Vector&, SparseMatrix&) const override;
25+
void print(std::ostream&) const override;
26+
};
27+
28+
} // namespace eckit::linalg::sparse

tests/linalg/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ ecbuild_add_test( TARGET eckit_test_linalg_sparse_backend_openmp
154154
CONDITION eckit_HAVE_OMP
155155
ARGS --log_level=message -linearAlgebraSparseBackend openmp )
156156

157+
ecbuild_add_test( TARGET eckit_test_linalg_sparse_backend_torch
158+
COMMAND eckit_test_linalg_sparse_backend
159+
CONDITION eckit_HAVE_TORCH
160+
ARGS --log_level=message -linearAlgebraSparseBackend torch )
161+
157162
#
158163

159164
ecbuild_add_test( TARGET eckit_test_linalg_sparse

0 commit comments

Comments
 (0)