Skip to content

Commit e04b7ae

Browse files
authored
Fix a non-determinism bug in CUDA (risc0#3451)
1 parent 8b56ddb commit e04b7ae

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

risc0/circuit/rv32im-m3-sys/cxx/hal/cuda/hal.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ extern "C" bool cuda_zero_dev(void* buf, size_t size);
4646
// api.cu
4747
extern "C" SparkError sppark_poseidon2_fold(void* d_out, const void* d_in, size_t num_hashes);
4848
extern "C" SparkError sppark_poseidon2_rows(void* d_out, const void* d_in, uint32_t count, uint32_t col_size);
49-
extern "C" SparkError sppark_prefix_sum(void* d_inout, uint32_t count);
49+
extern "C" void prefix_sum(Fp* d_inout, uint32_t count);
5050
extern "C" SparkError supra_poly_divide(void* d_inout, size_t len, void* remainder, FpExt pow);
5151

5252
// query.cu
@@ -290,10 +290,7 @@ class CudaHal : public IHal {
290290
size_t po2 = checkPo2(accum.rows());
291291
accum_witgen_cuda(toDevPtr(accum), toDevPtr(data), toDevPtr(globals), toDevPtr(accMix), risc0::ROU_FWD[po2], accum.rows());
292292
for (size_t i = 0; i < 4; i++) {
293-
auto err = sppark_prefix_sum(toDevPtr(accum) + accum.rows() * i, accum.rows());
294-
if (err.code != 0) {
295-
throw std::runtime_error(std::string("Error during computeAccumWitness:") + err.message);
296-
}
293+
prefix_sum(toDevPtr(accum) + accum.rows() * i, accum.rows());
297294
}
298295
}
299296

risc0/circuit/rv32im-m3-sys/cxx/hal/cuda/kernels/api.cu

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#include "poseidon2.cuh"
2626
//#include "poseidon254.cuh"
2727

28+
#include <thrust/execution_policy.h>
29+
#include <thrust/scan.h>
30+
2831
extern "C" RustError::by_value
2932
sppark_poseidon2_fold(poseidon_out_t* d_out, const poseidon_in_t* d_in, size_t num_hashes) {
3033
const gpu_t& gpu = select_gpu();
@@ -153,18 +156,8 @@ sppark_poseidon254_rows(alt_bn128::fr_t* d_out, const fr_t* d_in, size_t count,
153156

154157
#endif
155158

156-
extern "C" RustError::by_value sppark_prefix_sum(fr_t d_inout[/*count*/], uint32_t count) {
157-
const gpu_t& gpu = select_gpu();
158-
159-
try {
160-
prefix_op<Add<fr_t>>(d_inout, count, gpu);
161-
gpu.sync();
162-
} catch (const cuda_error& e) {
163-
gpu.sync();
164-
return RustError{e.code(), e.what()};
165-
}
166-
167-
return RustError{cudaSuccess};
159+
extern "C" void prefix_sum(fr_t* buf, uint32_t count) {
160+
thrust::inclusive_scan(thrust::device, buf, buf + count, buf);
168161
}
169162

170163
extern "C" RustError::by_value

risc0/circuit/rv32im-m3-sys/cxx/rv32im/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ cc_library(
66
deps = [
77
"//prove",
88
"//verify",
9+
"//hal/pick",
910
],
1011
)

risc0/circuit/rv32im-m3-sys/cxx/rv32im/ffi.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ const char* risc0_circuit_rv32im_m3_prove(const uint8_t* elf_ptr, size_t elf_len
4848
verifyRv32im(readIop, po2);
4949
readIop.done();
5050
} catch (const std::exception& err) {
51+
LOG(0, "ERROR: " << err.what());
5152
return strdup(err.what());
5253
} catch (...) {
54+
LOG(0, "UNKNOWN ERROR");
5355
return strdup("Generic exception");
5456
}
57+
LOG(0, "Completed successfuly");
5558
return nullptr;
5659
}
5760

risc0/circuit/rv32im-m3-sys/cxx/rv32im/test/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,10 @@ cc_test(
7474
],
7575
deps = ["//rv32im/test:test_prove"],
7676
)
77+
78+
cc_test(
79+
name = "test_ffi",
80+
srcs = ["test_ffi.cpp"],
81+
data = ["//rv32im/rvtest:riscv_test_bins"],
82+
deps = ["//rv32im"],
83+
)

risc0/circuit/rv32im-m3-sys/cxx/rv32im/test/test_ffi.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2025 RISC Zero, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4+
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5+
// http://opensource.org/licenses/MIT>, at your option. This file may not be
6+
// copied, modified, or distributed except according to those terms.
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
//
14+
// SPDX-License-Identifier: Apache-2.0 OR MIT
15+
16+
#include "core/log.h"
17+
#include "core/util.h"
18+
19+
extern "C" const char* risc0_circuit_rv32im_m3_prove(const uint8_t* elf_ptr, size_t elf_len);
20+
21+
void runTest(const std::string& name) {
22+
auto fullname = "rv32im/rvtest/" + name;
23+
auto elf = risc0::loadFile(fullname);
24+
const char* err = risc0_circuit_rv32im_m3_prove(elf.data(), elf.size());
25+
if (err != nullptr) {
26+
throw std::runtime_error(err);
27+
}
28+
}
29+
30+
int main() {
31+
LOG(0, "Hello world");
32+
runTest("add");
33+
runTest("addi");
34+
}

0 commit comments

Comments
 (0)