|
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#pragma once |
| 18 | + |
| 19 | +#include <cudf/utilities/error.hpp> |
| 20 | + |
| 21 | +#include <rmm/detail/error.hpp> |
| 22 | + |
| 23 | +#include <jni.h> |
| 24 | + |
| 25 | +namespace cudf::jni { |
| 26 | + |
| 27 | +// Wrapper for cudf JNI exception classes, which also store native stacktrace. |
| 28 | +constexpr char const* CUDA_EXCEPTION_CLASS = "ai/rapids/cudf/CudaException"; |
| 29 | +constexpr char const* CUDA_FATAL_EXCEPTION_CLASS = "ai/rapids/cudf/CudaFatalException"; |
| 30 | +constexpr char const* CUDF_EXCEPTION_CLASS = "ai/rapids/cudf/CudfException"; |
| 31 | +constexpr char const* CUDF_OVERFLOW_EXCEPTION_CLASS = |
| 32 | + "ai/rapids/cudf/CudfColumnSizeOverflowException"; |
| 33 | +constexpr char const* NVCOMP_EXCEPTION_CLASS = "ai/rapids/cudf/nvcomp/NvcompException"; |
| 34 | +constexpr char const* NVCOMP_CUDA_EXCEPTION_CLASS = "ai/rapids/cudf/nvcomp/NvcompCudaException"; |
| 35 | + |
| 36 | +// Java exceptions classes. |
| 37 | +constexpr char const* INDEX_OOB_EXCEPTION_CLASS = "java/lang/ArrayIndexOutOfBoundsException"; |
| 38 | +constexpr char const* ILLEGAL_ARG_EXCEPTION_CLASS = "java/lang/IllegalArgumentException"; |
| 39 | +constexpr char const* NPE_EXCEPTION_CLASS = "java/lang/NullPointerException"; |
| 40 | +constexpr char const* RUNTIME_EXCEPTION_CLASS = "java/lang/RuntimeException"; |
| 41 | +constexpr char const* UNSUPPORTED_EXCEPTION_CLASS = "java/lang/UnsupportedOperationException"; |
| 42 | + |
| 43 | +// Java error classes. |
| 44 | +// An error is a serious problem and the applications should not expect to recover from it. |
| 45 | +constexpr char const* OOM_ERROR_CLASS = "java/lang/OutOfMemoryError"; |
| 46 | + |
| 47 | +/** |
| 48 | + * @brief Exception class indicating that a JNI error of some kind was thrown and the main |
| 49 | + * function should return. |
| 50 | + */ |
| 51 | +class jni_exception : public std::runtime_error { |
| 52 | + public: |
| 53 | + jni_exception(char const* const message) : std::runtime_error(message) {} |
| 54 | + jni_exception(std::string const& message) : std::runtime_error(message) {} |
| 55 | +}; |
| 56 | + |
| 57 | +/** |
| 58 | + * @brief Throw a Java exception and a C++ one for flow control. |
| 59 | + */ |
| 60 | +inline void throw_java_exception(JNIEnv* const env, char const* class_name, char const* message) |
| 61 | +{ |
| 62 | + jclass ex_class = env->FindClass(class_name); |
| 63 | + if (ex_class != nullptr) { env->ThrowNew(ex_class, message); } |
| 64 | + throw jni_exception(message); |
| 65 | +} |
| 66 | + |
| 67 | +/** |
| 68 | + * @brief Check if a Java exceptions have been thrown and if so throw a C++ exception so the flow |
| 69 | + * control stop processing. |
| 70 | + */ |
| 71 | +inline void check_java_exception(JNIEnv* const env) |
| 72 | +{ |
| 73 | + if (env->ExceptionCheck()) { |
| 74 | + // Not going to try to get the message out of the Exception, too complex and |
| 75 | + // might fail. |
| 76 | + throw jni_exception("JNI Exception..."); |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +/** |
| 81 | + * @brief Create a cuda exception from a given cudaError_t. |
| 82 | + */ |
| 83 | +inline jthrowable cuda_exception(JNIEnv* const env, cudaError_t status, jthrowable cause = nullptr) |
| 84 | +{ |
| 85 | + // Calls cudaGetLastError twice. It is nearly certain that a fatal error occurred if the second |
| 86 | + // call doesn't return with cudaSuccess. |
| 87 | + cudaGetLastError(); |
| 88 | + auto const last = cudaGetLastError(); |
| 89 | + // Call cudaDeviceSynchronize to ensure `last` did not result from an asynchronous error |
| 90 | + // between two calls. |
| 91 | + auto const ex_class_name = status == last && last == cudaDeviceSynchronize() |
| 92 | + ? CUDA_FATAL_EXCEPTION_CLASS |
| 93 | + : CUDA_EXCEPTION_CLASS; |
| 94 | + |
| 95 | + jclass ex_class = env->FindClass(ex_class_name); |
| 96 | + if (ex_class == nullptr) { return nullptr; } |
| 97 | + jmethodID ctor_id = |
| 98 | + env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;ILjava/lang/Throwable;)V"); |
| 99 | + if (ctor_id == nullptr) { return nullptr; } |
| 100 | + |
| 101 | + jstring msg = env->NewStringUTF(cudaGetErrorString(status)); |
| 102 | + if (msg == nullptr) { return nullptr; } |
| 103 | + |
| 104 | + jint err_code = static_cast<jint>(status); |
| 105 | + |
| 106 | + jobject ret = env->NewObject(ex_class, ctor_id, msg, err_code, cause); |
| 107 | + return static_cast<jthrowable>(ret); |
| 108 | +} |
| 109 | + |
| 110 | +inline void jni_cuda_check(JNIEnv* const env, cudaError_t cuda_status) |
| 111 | +{ |
| 112 | + if (cudaSuccess != cuda_status) { |
| 113 | + jthrowable jt = cuda_exception(env, cuda_status); |
| 114 | + if (jt != nullptr) { env->Throw(jt); } |
| 115 | + throw jni_exception(std::string("CUDA ERROR: code ") + |
| 116 | + std::to_string(static_cast<int>(cuda_status))); |
| 117 | + } |
| 118 | +} |
| 119 | +} // namespace cudf::jni |
| 120 | + |
| 121 | +#define JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val) \ |
| 122 | + { \ |
| 123 | + if (env->ExceptionOccurred()) { return ret_val; } \ |
| 124 | + } |
| 125 | + |
| 126 | +#define JNI_THROW_NEW(env, class_name, message, ret_val) \ |
| 127 | + { \ |
| 128 | + jclass ex_class = env->FindClass(class_name); \ |
| 129 | + if (ex_class == nullptr) { return ret_val; } \ |
| 130 | + env->ThrowNew(ex_class, message); \ |
| 131 | + return ret_val; \ |
| 132 | + } |
| 133 | + |
| 134 | +// Throw a new exception only if one is not pending then always return with the specified value |
| 135 | +#define JNI_CHECK_THROW_CUDF_EXCEPTION(env, class_name, message, stacktrace, ret_val) \ |
| 136 | + { \ |
| 137 | + JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \ |
| 138 | + auto const ex_class = env->FindClass(class_name); \ |
| 139 | + if (ex_class == nullptr) { return ret_val; } \ |
| 140 | + auto const ctor_id = \ |
| 141 | + env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/lang/String;)V"); \ |
| 142 | + if (ctor_id == nullptr) { return ret_val; } \ |
| 143 | + auto const empty_str = std::string{""}; \ |
| 144 | + auto const jmessage = env->NewStringUTF(message == nullptr ? empty_str.c_str() : message); \ |
| 145 | + if (jmessage == nullptr) { return ret_val; } \ |
| 146 | + auto const jstacktrace = \ |
| 147 | + env->NewStringUTF(stacktrace == nullptr ? empty_str.c_str() : stacktrace); \ |
| 148 | + if (jstacktrace == nullptr) { return ret_val; } \ |
| 149 | + auto const jobj = env->NewObject(ex_class, ctor_id, jmessage, jstacktrace); \ |
| 150 | + if (jobj == nullptr) { return ret_val; } \ |
| 151 | + env->Throw(reinterpret_cast<jthrowable>(jobj)); \ |
| 152 | + return ret_val; \ |
| 153 | + } |
| 154 | + |
| 155 | +// Throw a new exception only if one is not pending then always return with the specified value |
| 156 | +#define JNI_CHECK_THROW_CUDA_EXCEPTION(env, class_name, message, stacktrace, error_code, ret_val) \ |
| 157 | + { \ |
| 158 | + JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \ |
| 159 | + auto const ex_class = env->FindClass(class_name); \ |
| 160 | + if (ex_class == nullptr) { return ret_val; } \ |
| 161 | + auto const ctor_id = \ |
| 162 | + env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/lang/String;I)V"); \ |
| 163 | + if (ctor_id == nullptr) { return ret_val; } \ |
| 164 | + auto const empty_str = std::string{""}; \ |
| 165 | + auto const jmessage = env->NewStringUTF(message == nullptr ? empty_str.c_str() : message); \ |
| 166 | + if (jmessage == nullptr) { return ret_val; } \ |
| 167 | + auto const jstacktrace = \ |
| 168 | + env->NewStringUTF(stacktrace == nullptr ? empty_str.c_str() : stacktrace); \ |
| 169 | + if (jstacktrace == nullptr) { return ret_val; } \ |
| 170 | + auto const jerror_code = static_cast<jint>(error_code); \ |
| 171 | + auto const jobj = env->NewObject(ex_class, ctor_id, jmessage, jstacktrace, jerror_code); \ |
| 172 | + if (jobj == nullptr) { return ret_val; } \ |
| 173 | + env->Throw(reinterpret_cast<jthrowable>(jobj)); \ |
| 174 | + return ret_val; \ |
| 175 | + } |
| 176 | + |
| 177 | +#define JNI_NULL_CHECK(env, obj, error_msg, ret_val) \ |
| 178 | + { \ |
| 179 | + if ((obj) == 0) { JNI_THROW_NEW(env, cudf::jni::NPE_EXCEPTION_CLASS, error_msg, ret_val); } \ |
| 180 | + } |
| 181 | + |
| 182 | +#define JNI_ARG_CHECK(env, obj, error_msg, ret_val) \ |
| 183 | + { \ |
| 184 | + if (!(obj)) { \ |
| 185 | + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, error_msg, ret_val); \ |
| 186 | + } \ |
| 187 | + } |
| 188 | + |
| 189 | +#define CATCH_STD_CLASS(env, class_name, ret_val) \ |
| 190 | + catch (const rmm::out_of_memory& e) \ |
| 191 | + { \ |
| 192 | + JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \ |
| 193 | + auto const what = \ |
| 194 | + std::string("Could not allocate native memory: ") + (e.what() == nullptr ? "" : e.what()); \ |
| 195 | + JNI_THROW_NEW(env, cudf::jni::OOM_ERROR_CLASS, what.c_str(), ret_val); \ |
| 196 | + } \ |
| 197 | + catch (const cudf::fatal_cuda_error& e) \ |
| 198 | + { \ |
| 199 | + JNI_CHECK_THROW_CUDA_EXCEPTION(env, \ |
| 200 | + cudf::jni::CUDA_FATAL_EXCEPTION_CLASS, \ |
| 201 | + e.what(), \ |
| 202 | + e.stacktrace(), \ |
| 203 | + e.error_code(), \ |
| 204 | + ret_val); \ |
| 205 | + } \ |
| 206 | + catch (const cudf::cuda_error& e) \ |
| 207 | + { \ |
| 208 | + JNI_CHECK_THROW_CUDA_EXCEPTION( \ |
| 209 | + env, cudf::jni::CUDA_EXCEPTION_CLASS, e.what(), e.stacktrace(), e.error_code(), ret_val); \ |
| 210 | + } \ |
| 211 | + catch (const cudf::data_type_error& e) \ |
| 212 | + { \ |
| 213 | + JNI_CHECK_THROW_CUDF_EXCEPTION( \ |
| 214 | + env, cudf::jni::CUDF_EXCEPTION_CLASS, e.what(), e.stacktrace(), ret_val); \ |
| 215 | + } \ |
| 216 | + catch (std::overflow_error const& e) \ |
| 217 | + { \ |
| 218 | + JNI_CHECK_THROW_CUDF_EXCEPTION(env, \ |
| 219 | + cudf::jni::CUDF_OVERFLOW_EXCEPTION_CLASS, \ |
| 220 | + e.what(), \ |
| 221 | + "No native stacktrace is available.", \ |
| 222 | + ret_val); \ |
| 223 | + } \ |
| 224 | + catch (const std::exception& e) \ |
| 225 | + { \ |
| 226 | + char const* stacktrace = "No native stacktrace is available."; \ |
| 227 | + if (auto const cudf_ex = dynamic_cast<cudf::logic_error const*>(&e); cudf_ex != nullptr) { \ |
| 228 | + stacktrace = cudf_ex->stacktrace(); \ |
| 229 | + } \ |
| 230 | + /* Double check whether the thrown exception is unrecoverable CUDA error or not. */ \ |
| 231 | + /* Like cudf::detail::throw_cuda_error, it is nearly certain that a fatal error */ \ |
| 232 | + /* occurred if the second call doesn't return with cudaSuccess. */ \ |
| 233 | + cudaGetLastError(); \ |
| 234 | + auto const last = cudaFree(0); \ |
| 235 | + if (cudaSuccess != last && last == cudaDeviceSynchronize()) { \ |
| 236 | + /* Throw CudaFatalException since the thrown exception is unrecoverable CUDA error */ \ |
| 237 | + JNI_CHECK_THROW_CUDA_EXCEPTION( \ |
| 238 | + env, cudf::jni::CUDA_FATAL_EXCEPTION_CLASS, e.what(), stacktrace, last, ret_val); \ |
| 239 | + } \ |
| 240 | + JNI_CHECK_THROW_CUDF_EXCEPTION(env, class_name, e.what(), stacktrace, ret_val); \ |
| 241 | + } |
| 242 | + |
| 243 | +#define CATCH_STD(env, ret_val) CATCH_STD_CLASS(env, cudf::jni::CUDF_EXCEPTION_CLASS, ret_val) |
0 commit comments