Skip to content

Commit 363f92d

Browse files
authored
Merge pull request #1196 from mcabbott/sortperm
Allow sorting of tuples of numbers
2 parents 55ed099 + 5a9611f commit 363f92d

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

src/sorting.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ Uses block y index to decide which values to operate on.
7373
sync_threads()
7474
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
7575
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
76-
val = idx0 <= hi ? values[idx0] : one(eltype(values))
77-
comparison = flex_lt(pivot, val, parity, lt, by)
76+
@inbounds if idx0 <= hi
77+
val = values[idx0]
78+
comparison = flex_lt(pivot, val, parity, lt, by)
79+
end
7880

7981
@inbounds if idx0 <= hi
8082
sums[threadIdx().x] = 1 & comparison
@@ -85,9 +87,11 @@ Uses block y index to decide which values to operate on.
8587

8688
cumsum!(sums)
8789

88-
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
89-
@inbounds if idx0 <= hi && dest_idx <= length(swap)
90-
swap[dest_idx] = val
90+
@inbounds if idx0 <= hi
91+
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
92+
if dest_idx <= length(swap)
93+
swap[dest_idx] = val
94+
end
9195
end
9296
sync_threads()
9397

@@ -180,10 +184,8 @@ Must only run on 1 SM.
180184
c = n_eff() - d
181185
to_move = min(b, c)
182186
sync_threads()
183-
swap = if threadIdx().x <= to_move
184-
vals[lo + a + threadIdx().x]
185-
else
186-
zero(eltype(vals)) # unused value
187+
if threadIdx().x <= to_move
188+
swap = vals[lo + a + threadIdx().x]
187189
end
188190
sync_threads()
189191
if threadIdx().x <= to_move
@@ -215,7 +217,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
215217

216218
@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
217219
sync_threads()
218-
old_val = zero(eltype(swap))
219220

220221
log_blockDim = begin
221222
out = 0
@@ -269,10 +270,8 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
269270
for level in 0:L
270271
# get left/right neighbor depending on even/odd level
271272
buddy = threadIdx().x - 1i32 + 2i32 * (1i32 & (threadIdx().x % 2i32 != level % 2i32))
272-
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
273-
swap[buddy]
274-
else
275-
zero(eltype(swap)) # unused value
273+
if 1 <= buddy <= L && threadIdx().x <= L
274+
buddy_val = swap[buddy]
276275
end
277276
sync_threads()
278277
if 1 <= buddy <= L && threadIdx().x <= L
@@ -738,7 +737,7 @@ Each view is indexed along block x dim: one view per pseudo-block
738737
@inbounds swap[threadIdx().x, threadIdx().y] = vals[index+one(I)]
739738
end
740739
sync_threads()
741-
return @view swap[:, threadIdx().y]
740+
return @inbounds @view swap[:, threadIdx().y]
742741
end
743742

744743
"""

test/sorting.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ end
280280
end
281281
end
282282

283+
# XXX: some tests here make compute-sanitizer hang, but only on CI.
284+
# maybe related to the container set-up? try again once we use Sandbox.jl.
285+
283286
@testset "interface" begin
284287
@testset "quicksort" begin
285288
# pre-sorted
@@ -302,6 +305,7 @@ end
302305
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.QuickSort)
303306
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.QuickSort)
304307
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.QuickSort)
308+
@not_if_sanitize @test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.QuickSort)
305309

306310
# non-uniform distributions
307311
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 2); alg=CUDA.QuickSort)
@@ -345,6 +349,7 @@ end
345349
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.BitonicSort)
346350
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.BitonicSort)
347351
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.BitonicSort)
352+
@not_if_sanitize @test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.BitonicSort)
348353

349354
# test various sizes
350355
@test check_sort!(Float32, 1, x -> rand(Float32); alg=CUDA.BitonicSort)

0 commit comments

Comments
 (0)