Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .cudaq_version
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"cudaq": {
"repository": "NVIDIA/cuda-quantum",
"ref": "5b1000af08ad059c82af94f18782f06e0ecd12d3"
"ref": "887759d8152894b9ae4851d790724938eafdd149"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initial testing looks great, so nothing required for now, but just noting that we will want to revert this change (or change it to a commit that is actually on cuda-quantum main branch) before actually merging this CUDA-QX PR.

}
}
202 changes: 0 additions & 202 deletions libs/core/include/cuda-qx/core/extension_point.h

This file was deleted.

1 change: 0 additions & 1 deletion libs/core/include/cuda-qx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
******************************************************************************/
#pragma once

#include "extension_point.h"
#include "tensor_impl.h"
#include "type_traits.h"

Expand Down
11 changes: 6 additions & 5 deletions libs/core/include/cuda-qx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,27 @@
******************************************************************************/
#pragma once

#include "extension_point.h"
#include "cudaq/utils/extension_point.h"

#include <complex>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>

namespace cudaqx::details {

/// @brief Implementation class for tensor operations following the PIMPL idiom
template <typename Scalar = std::complex<double>>
class tensor_impl
: public cudaqx::extension_point<tensor_impl<Scalar>, const Scalar *,
const std::vector<std::size_t>> {
: public cudaq::extension_point<tensor_impl<Scalar>, const Scalar *,
const std::vector<std::size_t>> {
public:
/// @brief Type alias for the scalar type used in the tensor
using scalar_type = Scalar;
using BaseExtensionPoint =
cudaqx::extension_point<tensor_impl<Scalar>, const Scalar *,
const std::vector<std::size_t>>;
cudaq::extension_point<tensor_impl<Scalar>, const Scalar *,
const std::vector<std::size_t>>;

/// @brief Create a tensor implementation with the given name and shape
/// @param name The name of the tensor implementation
Expand Down
2 changes: 2 additions & 0 deletions libs/core/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ target_include_directories(cudaqx-core
)

target_link_libraries(cudaqx-core
PUBLIC
cudaq::cudaq
PRIVATE
xtensor
xtensor-blas
Expand Down
9 changes: 6 additions & 3 deletions libs/core/lib/tensor_impls/xtensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class xtensor : public cudaqx::details::tensor_impl<Scalar> {
/// @param d Pointer to the tensor data
/// @param s Shape of the tensor
/// @return A unique pointer to the created xtensor object
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(
CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(
xtensor<Scalar>, std::string("xtensor") + std::string(ScalarAsString),
static std::unique_ptr<cudaqx::details::tensor_impl<Scalar>> create(
const Scalar *d, const std::vector<std::size_t> s) {
Expand All @@ -310,11 +310,13 @@ class xtensor : public cudaqx::details::tensor_impl<Scalar> {
}
};

} // namespace cudaqx

/// @brief Register the xtensor types

#define INSTANTIATE_REGISTRY_TENSOR_IMPL(TYPE) \
INSTANTIATE_REGISTRY(cudaqx::details::tensor_impl<TYPE>, const TYPE *, \
const std::vector<std::size_t>)
CUDAQ_INSTANTIATE_REGISTRY(cudaqx::details::tensor_impl<TYPE>, const TYPE *, \
const std::vector<std::size_t>)

INSTANTIATE_REGISTRY_TENSOR_IMPL(std::complex<double>)
INSTANTIATE_REGISTRY_TENSOR_IMPL(std::complex<float>)
Expand All @@ -324,6 +326,7 @@ INSTANTIATE_REGISTRY_TENSOR_IMPL(double)
INSTANTIATE_REGISTRY_TENSOR_IMPL(float)
INSTANTIATE_REGISTRY_TENSOR_IMPL(std::size_t)

namespace cudaqx {
template <>
const bool xtensor<std::complex<double>>::registered_ =
xtensor<std::complex<double>>::register_type();
Expand Down
Loading
Loading