Skip to content

Commit 9f92842

Browse files
authored
Refactor JNI error handling (#18983)
This PR refactors the Java JNI code, extracting the error-handling code into a separate header along with some simple cleanup. This modification paves the way for capturing native exceptions with stacktraces in follow-up work. No new feature or implementation is added; code is only moved around. Since some variable names (macros for strings) are changed, this will break the `spark-rapids-jni` build, so a fix for it (NVIDIA/spark-rapids-jni#3375) should be merged right after this PR. This is part of [[Epic] Capture native stacktrace when throwing exception using cpptrace #3398](NVIDIA/spark-rapids-jni#3398). Authors: - Nghia Truong (https://github.com/ttnghia) Approvers: - Alessandro Bellina (https://github.com/abellina) URL: #18983
1 parent 4474202 commit 9f92842

File tree

11 files changed

+418
-345
lines changed

11 files changed

+418
-345
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include <cudf/utilities/error.hpp>
20+
21+
#include <rmm/detail/error.hpp>
22+
23+
#include <jni.h>
24+
25+
namespace cudf::jni {
26+
27+
// Slash-separated names of the cudf Java exception classes, in the form passed to
// JNIEnv::FindClass. These wrapper classes also store the native stacktrace.
//
// The constants are `inline` so that every translation unit including this header shares a
// single definition instead of getting its own internal-linkage copy.
inline constexpr char const* CUDA_EXCEPTION_CLASS       = "ai/rapids/cudf/CudaException";
inline constexpr char const* CUDA_FATAL_EXCEPTION_CLASS = "ai/rapids/cudf/CudaFatalException";
inline constexpr char const* CUDF_EXCEPTION_CLASS       = "ai/rapids/cudf/CudfException";
inline constexpr char const* CUDF_OVERFLOW_EXCEPTION_CLASS =
  "ai/rapids/cudf/CudfColumnSizeOverflowException";
inline constexpr char const* NVCOMP_EXCEPTION_CLASS = "ai/rapids/cudf/nvcomp/NvcompException";
inline constexpr char const* NVCOMP_CUDA_EXCEPTION_CLASS =
  "ai/rapids/cudf/nvcomp/NvcompCudaException";

// Standard Java exception classes.
inline constexpr char const* INDEX_OOB_EXCEPTION_CLASS = "java/lang/ArrayIndexOutOfBoundsException";
inline constexpr char const* ILLEGAL_ARG_EXCEPTION_CLASS = "java/lang/IllegalArgumentException";
inline constexpr char const* NPE_EXCEPTION_CLASS         = "java/lang/NullPointerException";
inline constexpr char const* RUNTIME_EXCEPTION_CLASS     = "java/lang/RuntimeException";
inline constexpr char const* UNSUPPORTED_EXCEPTION_CLASS = "java/lang/UnsupportedOperationException";

// Standard Java error classes.
// An error is a serious problem that applications should not expect to recover from.
inline constexpr char const* OOM_ERROR_CLASS = "java/lang/OutOfMemoryError";
46+
47+
/**
 * @brief C++ exception used for flow control after a JNI-level error.
 *
 * Signals that a JNI error of some kind was raised and that the main function should stop
 * processing and return.
 */
class jni_exception : public std::runtime_error {
 public:
  /// Build the exception from a C-string description.
  jni_exception(char const* const what_arg) : std::runtime_error{what_arg} {}

  /// Build the exception from a std::string description.
  jni_exception(std::string const& what_arg) : std::runtime_error{what_arg} {}
};
56+
57+
/**
58+
* @brief Throw a Java exception and a C++ one for flow control.
59+
*/
60+
inline void throw_java_exception(JNIEnv* const env, char const* class_name, char const* message)
61+
{
62+
jclass ex_class = env->FindClass(class_name);
63+
if (ex_class != nullptr) { env->ThrowNew(ex_class, message); }
64+
throw jni_exception(message);
65+
}
66+
67+
/**
68+
* @brief Check if a Java exceptions have been thrown and if so throw a C++ exception so the flow
69+
* control stop processing.
70+
*/
71+
inline void check_java_exception(JNIEnv* const env)
72+
{
73+
if (env->ExceptionCheck()) {
74+
// Not going to try to get the message out of the Exception, too complex and
75+
// might fail.
76+
throw jni_exception("JNI Exception...");
77+
}
78+
}
79+
80+
/**
81+
* @brief Create a cuda exception from a given cudaError_t.
82+
*/
83+
inline jthrowable cuda_exception(JNIEnv* const env, cudaError_t status, jthrowable cause = nullptr)
84+
{
85+
// Calls cudaGetLastError twice. It is nearly certain that a fatal error occurred if the second
86+
// call doesn't return with cudaSuccess.
87+
cudaGetLastError();
88+
auto const last = cudaGetLastError();
89+
// Call cudaDeviceSynchronize to ensure `last` did not result from an asynchronous error
90+
// between two calls.
91+
auto const ex_class_name = status == last && last == cudaDeviceSynchronize()
92+
? CUDA_FATAL_EXCEPTION_CLASS
93+
: CUDA_EXCEPTION_CLASS;
94+
95+
jclass ex_class = env->FindClass(ex_class_name);
96+
if (ex_class == nullptr) { return nullptr; }
97+
jmethodID ctor_id =
98+
env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;ILjava/lang/Throwable;)V");
99+
if (ctor_id == nullptr) { return nullptr; }
100+
101+
jstring msg = env->NewStringUTF(cudaGetErrorString(status));
102+
if (msg == nullptr) { return nullptr; }
103+
104+
jint err_code = static_cast<jint>(status);
105+
106+
jobject ret = env->NewObject(ex_class, ctor_id, msg, err_code, cause);
107+
return static_cast<jthrowable>(ret);
108+
}
109+
110+
/**
 * @brief Turn a failed CUDA status into a Java exception plus a C++ `jni_exception`.
 *
 * Does nothing when `cuda_status` is `cudaSuccess`. Otherwise a Java exception built by
 * `cuda_exception` is raised (when it could be created) and a `jni_exception` carrying the
 * numeric error code is thrown.
 *
 * @param env JNI environment used to raise the Java exception
 * @param cuda_status CUDA status code to check
 */
inline void jni_cuda_check(JNIEnv* const env, cudaError_t cuda_status)
{
  if (cuda_status == cudaSuccess) { return; }

  if (auto const java_exception = cuda_exception(env, cuda_status); java_exception != nullptr) {
    env->Throw(java_exception);
  }

  auto const error_code = static_cast<int>(cuda_status);
  throw jni_exception("CUDA ERROR: code " + std::to_string(error_code));
}
119+
} // namespace cudf::jni
120+
121+
// Return `ret_val` from the enclosing function if a Java exception is already pending on `env`.
//
// ExceptionCheck is used instead of ExceptionOccurred so that no local reference to the
// pending throwable is created merely to test for its presence.
#define JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val) \
  { \
    if ((env)->ExceptionCheck()) { return ret_val; } \
  }
125+
126+
// Raise a new Java exception of class `class_name` with `message`, then return `ret_val` from
// the enclosing function. If the class cannot be found, nothing is raised here and `ret_val`
// is still returned.
#define JNI_THROW_NEW(env, class_name, message, ret_val) \
  { \
    if (auto const thrown_class = env->FindClass(class_name); thrown_class != nullptr) { \
      env->ThrowNew(thrown_class, message); \
    } \
    return ret_val; \
  }
133+
134+
// Throw a new exception only if one is not already pending, then always return `ret_val` from
// the enclosing function.
//
// `class_name` must name a Java class with a (String message, String stacktrace) constructor.
// A null `message` or `stacktrace` is replaced by an empty string. If any JNI step fails (class
// or constructor lookup, string or object creation), `ret_val` is returned without raising a
// new exception here.
#define JNI_CHECK_THROW_CUDF_EXCEPTION(env, class_name, message, stacktrace, ret_val) \
  { \
    JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \
    auto const ex_class = (env)->FindClass(class_name); \
    if (ex_class == nullptr) { return ret_val; } \
    auto const ctor_id = \
      (env)->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/lang/String;)V"); \
    if (ctor_id == nullptr) { return ret_val; } \
    auto const jmessage = (env)->NewStringUTF((message) == nullptr ? "" : (message)); \
    if (jmessage == nullptr) { return ret_val; } \
    auto const jstacktrace = \
      (env)->NewStringUTF((stacktrace) == nullptr ? "" : (stacktrace)); \
    if (jstacktrace == nullptr) { return ret_val; } \
    auto const jobj = (env)->NewObject(ex_class, ctor_id, jmessage, jstacktrace); \
    if (jobj == nullptr) { return ret_val; } \
    /* jobject -> jthrowable is a plain downcast in the C++ JNI types. */ \
    (env)->Throw(static_cast<jthrowable>(jobj)); \
    return ret_val; \
  }
154+
155+
// Throw a new exception only if one is not already pending, then always return `ret_val` from
// the enclosing function.
//
// `class_name` must name a Java class with a (String message, String stacktrace, int errorCode)
// constructor. A null `message` or `stacktrace` is replaced by an empty string and `error_code`
// is converted to `jint`. If any JNI step fails (class or constructor lookup, string or object
// creation), `ret_val` is returned without raising a new exception here.
#define JNI_CHECK_THROW_CUDA_EXCEPTION(env, class_name, message, stacktrace, error_code, ret_val) \
  { \
    JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \
    auto const ex_class = (env)->FindClass(class_name); \
    if (ex_class == nullptr) { return ret_val; } \
    auto const ctor_id = \
      (env)->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/lang/String;I)V"); \
    if (ctor_id == nullptr) { return ret_val; } \
    auto const jmessage = (env)->NewStringUTF((message) == nullptr ? "" : (message)); \
    if (jmessage == nullptr) { return ret_val; } \
    auto const jstacktrace = \
      (env)->NewStringUTF((stacktrace) == nullptr ? "" : (stacktrace)); \
    if (jstacktrace == nullptr) { return ret_val; } \
    auto const jerror_code = static_cast<jint>(error_code); \
    auto const jobj = (env)->NewObject(ex_class, ctor_id, jmessage, jstacktrace, jerror_code); \
    if (jobj == nullptr) { return ret_val; } \
    /* jobject -> jthrowable is a plain downcast in the C++ JNI types. */ \
    (env)->Throw(static_cast<jthrowable>(jobj)); \
    return ret_val; \
  }
176+
177+
// Raise a Java NullPointerException with `error_msg` and return `ret_val` from the enclosing
// function when `obj` compares equal to 0.
// NOTE(review): the comparison is against 0 rather than nullptr, presumably so that integer
// handles can be checked as well as pointers; confirm against callers.
#define JNI_NULL_CHECK(env, obj, error_msg, ret_val) \
  { \
    if ((obj) == 0) { \
      JNI_THROW_NEW(env, cudf::jni::NPE_EXCEPTION_CLASS, error_msg, ret_val); \
    } \
  }
181+
182+
// Raise a Java IllegalArgumentException with `error_msg` and return `ret_val` from the
// enclosing function when the condition `obj` evaluates to false.
#define JNI_ARG_CHECK(env, obj, error_msg, ret_val) \
  { \
    if (!(obj)) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, error_msg, ret_val); } \
  }
188+
189+
// Chain of catch handlers that converts native C++ exceptions into Java exceptions and then
// returns `ret_val` from the enclosing function. It must directly follow a `try` block.
//
// Handlers, in order:
//   - rmm::out_of_memory     -> OOM_ERROR_CLASS (skipped if a Java exception is already pending)
//   - cudf::fatal_cuda_error -> CUDA_FATAL_EXCEPTION_CLASS
//   - cudf::cuda_error       -> CUDA_EXCEPTION_CLASS
//   - cudf::data_type_error  -> CUDF_EXCEPTION_CLASS
//   - std::overflow_error    -> CUDF_OVERFLOW_EXCEPTION_CLASS
//   - std::exception         -> CUDA_FATAL_EXCEPTION_CLASS if the CUDA runtime reports a
//                               persistent error, otherwise `class_name`
#define CATCH_STD_CLASS(env, class_name, ret_val) \
  catch (rmm::out_of_memory const& e) \
  { \
    JNI_EXCEPTION_OCCURRED_CHECK(env, ret_val); \
    auto const oom_message = \
      std::string("Could not allocate native memory: ") + (e.what() == nullptr ? "" : e.what()); \
    JNI_THROW_NEW(env, cudf::jni::OOM_ERROR_CLASS, oom_message.c_str(), ret_val); \
  } \
  catch (cudf::fatal_cuda_error const& e) \
  { \
    JNI_CHECK_THROW_CUDA_EXCEPTION( \
      env, cudf::jni::CUDA_FATAL_EXCEPTION_CLASS, e.what(), e.stacktrace(), e.error_code(), ret_val); \
  } \
  catch (cudf::cuda_error const& e) \
  { \
    JNI_CHECK_THROW_CUDA_EXCEPTION( \
      env, cudf::jni::CUDA_EXCEPTION_CLASS, e.what(), e.stacktrace(), e.error_code(), ret_val); \
  } \
  catch (cudf::data_type_error const& e) \
  { \
    JNI_CHECK_THROW_CUDF_EXCEPTION( \
      env, cudf::jni::CUDF_EXCEPTION_CLASS, e.what(), e.stacktrace(), ret_val); \
  } \
  catch (std::overflow_error const& e) \
  { \
    JNI_CHECK_THROW_CUDF_EXCEPTION(env, \
                                   cudf::jni::CUDF_OVERFLOW_EXCEPTION_CLASS, \
                                   e.what(), \
                                   "No native stacktrace is available.", \
                                   ret_val); \
  } \
  catch (std::exception const& e) \
  { \
    /* Only cudf::logic_error is probed for a native stacktrace here. */ \
    char const* native_trace = "No native stacktrace is available."; \
    auto const logic_err     = dynamic_cast<cudf::logic_error const*>(&e); \
    if (logic_err != nullptr) { native_trace = logic_err->stacktrace(); } \
    /* Double check whether the thrown exception is an unrecoverable CUDA error or not. */ \
    /* Like cudf::detail::throw_cuda_error, it is nearly certain that a fatal error */ \
    /* occurred if the second call doesn't return with cudaSuccess. */ \
    cudaGetLastError(); \
    auto const probe_status = cudaFree(0); \
    bool const is_fatal = \
      (probe_status != cudaSuccess) && (probe_status == cudaDeviceSynchronize()); \
    if (is_fatal) { \
      /* Throw CudaFatalException since the thrown exception is an unrecoverable CUDA error. */ \
      JNI_CHECK_THROW_CUDA_EXCEPTION(env, \
                                     cudf::jni::CUDA_FATAL_EXCEPTION_CLASS, \
                                     e.what(), \
                                     native_trace, \
                                     probe_status, \
                                     ret_val); \
    } \
    JNI_CHECK_THROW_CUDF_EXCEPTION(env, class_name, e.what(), native_trace, ret_val); \
  }
242+
243+
// Same as CATCH_STD_CLASS, with CUDF_EXCEPTION_CLASS as the Java class used for generic
// std::exception failures.
#define CATCH_STD(env, ret_val) \
  CATCH_STD_CLASS(env, cudf::jni::CUDF_EXCEPTION_CLASS, ret_val)

0 commit comments

Comments
 (0)