Skip to content

Commit 6edf8ac

Browse files
committed
WIP: Refactor memory hierarchy.
1 parent abc5009 commit 6edf8ac

File tree

13 files changed

+50
-165
lines changed

13 files changed

+50
-165
lines changed

LocalPreferences.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[OpenCL]
22
# Which memory back-end to use for unspecified CLArray allocations. This can be:
33
# - "usm": Unified Shared Memory (`cl_intel_unified_shared_memory`)
4-
# - "bda": Buffer Device Address (`cl_mem` + `cl_ext_buffer_device_address`)
4+
# - "bda": plain `cl_mem` buffers with a device address (`cl_ext_buffer_device_address`)
55
# - "svm": Shared Virtual Memory (coarse-grained)
66
# If unspecified, the default will be used based on the platform and device capabilities.
77
#default_memory_backend="..."

lib/cl/CL.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ include("device.jl")
2020
include("context.jl")
2121
include("cmdqueue.jl")
2222
include("event.jl")
23-
include("buffer.jl")
24-
include("memory/memory.jl")
23+
include("memory.jl")
2524
include("program.jl")
2625
include("kernel.jl")
2726

lib/cl/kernel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function set_arg!(k::Kernel, idx::Integer, arg::AbstractMemory)
7979
clSetKernelArgSVMPointer(k, idx - 1, pointer(arg))
8080
elseif arg isa UnifiedMemory
8181
clSetKernelArgMemPointerINTEL(k, idx - 1, pointer(arg))
82-
elseif arg isa BufferDeviceMemory
82+
elseif arg isa Buffer
8383
clSetKernelArgDevicePointerEXT(k, idx - 1, pointer(arg))
8484
else
8585
error("Unknown memory type")
@@ -203,7 +203,7 @@ function call(
203203

204204
if memory isa SharedVirtualMemory
205205
push!(svm_pointers, ptr)
206-
elseif memory isa BufferDeviceMemory
206+
elseif memory isa Buffer
207207
push!(bda_pointers, ptr)
208208
elseif memory isa UnifiedDeviceMemory
209209
device_access = true

lib/cl/memory/bda.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

lib/cl/buffer.jl renamed to lib/cl/memory/buffer.jl

Lines changed: 16 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,17 @@
1-
# OpenCL Memory Object
2-
3-
abstract type AbstractMemoryObject <: CLObject end
4-
5-
#This should be implemented by all subtypes
6-
# type MemoryType <: AbstractMemoryObject
7-
# id::cl_mem
8-
# ...
9-
# end
10-
11-
# for passing buffers to OpenCL APIs: use the underlying handle
12-
Base.unsafe_convert(::Type{cl_mem}, mem::AbstractMemoryObject) = mem.id
13-
14-
# for passing buffers to kernels: keep the buffer, it's handled by `cl.set_arg!`
15-
Base.unsafe_convert(::Type{<:Ptr}, mem::AbstractMemoryObject) = mem
16-
17-
Base.sizeof(mem::AbstractMemoryObject) = mem.size
18-
19-
release(mem::AbstractMemoryObject) = clReleaseMemObject(mem)
20-
21-
function Base.getproperty(mem::AbstractMemoryObject, s::Symbol)
22-
if s == :context
23-
param = Ref{cl_context}()
24-
clGetMemObjectInfo(mem, CL_MEM_CONTEXT, sizeof(cl_context), param, C_NULL)
25-
return Context(param[], retain = true)
26-
elseif s == :mem_type
27-
result = Ref{cl_mem_object_type}()
28-
clGetMemObjectInfo(mem, CL_MEM_TYPE, sizeof(cl_mem_object_type), result, C_NULL)
29-
return result[]
30-
elseif s == :mem_flags
31-
result = Ref{cl_mem_flags}()
32-
clGetMemObjectInfo(mem, CL_MEM_FLAGS, sizeof(cl_mem_flags), result, C_NULL)
33-
mf = result[]
34-
flags = Symbol[]
35-
if (mf & CL_MEM_READ_WRITE) != 0
36-
push!(flags, :rw)
37-
end
38-
if (mf & CL_MEM_WRITE_ONLY) != 0
39-
push!(flags, :w)
40-
end
41-
if (mf & CL_MEM_READ_ONLY) != 0
42-
push!(flags, :r)
43-
end
44-
if (mf & CL_MEM_USE_HOST_PTR) != 0
45-
push!(flags, :use)
46-
end
47-
if (mf & CL_MEM_ALLOC_HOST_PTR) != 0
48-
push!(flags, :alloc)
49-
end
50-
if (mf & CL_MEM_COPY_HOST_PTR) != 0
51-
push!(flags, :copy)
52-
end
53-
return tuple(flags...)
54-
elseif s == :size
55-
result = Ref{Csize_t}()
56-
clGetMemObjectInfo(mem, CL_MEM_SIZE, sizeof(Csize_t), result, C_NULL)
57-
return result[]
58-
elseif s == :reference_count
59-
result = Ref{Cuint}()
60-
clGetMemObjectInfo(mem, CL_MEM_REFERENCE_COUNT, sizeof(Cuint), result, C_NULL)
61-
return Int(result[])
62-
elseif s == :map_count
63-
result = Ref{Cuint}()
64-
clGetMemObjectInfo(mem, CL_MEM_MAP_COUNT, sizeof(Cuint), result, C_NULL)
65-
return Int(result[])
66-
elseif s == :device_address
67-
result = Ref{cl_mem_device_address_ext}()
68-
clGetMemObjectInfo(mem, CL_MEM_DEVICE_ADDRESS_EXT, sizeof(cl_mem_device_address_ext), result, C_NULL)
69-
return CLPtr{Cvoid}(result[])
70-
else
71-
return getfield(mem, s)
72-
end
73-
end
74-
75-
# convenience functions
76-
context(mem::AbstractMemoryObject) = mem.context
77-
Base.pointer(mem::AbstractMemoryObject) = mem.pointer
78-
79-
#TODO: enqueue_migrate_mem_objects(queue, mem_objects, flags=0, wait_for=None)
80-
#TODO: enqueue_migrate_mem_objects_ext(queue, mem_objects, flags=0, wait_for=None)
81-
821
# OpenCL.Buffer
832

843
struct Buffer <: AbstractMemoryObject
854
id::cl_mem
5+
ptr::Union{Nothing,CLPtr{Cvoid}}
866
bytesize::Int
7+
context::Context
878
end
889

10+
Buffer() = Buffer(C_NULL, nothing, 0, context())
11+
12+
Base.pointer(buf::Buffer) = @something buf.ptr error("Buffer does not have a device private address")
8913
Base.sizeof(buf::Buffer) = buf.bytesize
14+
context(buf::Buffer) = buf.context
9015

9116

9217
## constructors
@@ -130,7 +55,16 @@ function Buffer(sz::Int, flags::Integer, hostbuf=nothing;
13055
if err_code[] != CL_SUCCESS
13156
throw(CLError(err_code[]))
13257
end
133-
return Buffer(mem_id, sz)
58+
59+
ptr = if device_private_address
60+
ptr_ref = Ref{cl_mem_device_address_ext}()
61+
clGetMemObjectInfo(mem_id, CL_MEM_DEVICE_ADDRESS_EXT, sizeof(cl_mem_device_address_ext), ptr_ref, C_NULL)
62+
CLPtr{Cvoid}(ptr_ref[])
63+
else
64+
nothing
65+
end
66+
67+
return Buffer(mem_id, ptr, sz, context())
13468
end
13569

13670
# allocated buffer

lib/cl/memory/memory.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

lib/cl/memory/svm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct SharedVirtualMemory <: AbstractMemory
1+
struct SharedVirtualMemory <: AbstractPointerMemory
22
ptr::CLPtr{Cvoid}
33
bytesize::Int
44
context::Context

lib/cl/memory/usm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type UnifiedMemory <: AbstractMemory end
1+
abstract type UnifiedMemory <: AbstractPointerMemory end
22

33
function usm_free(mem::UnifiedMemory; blocking::Bool = false)
44
if blocking

lib/cl/state.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ function default_memory_backend(dev::Device)
199199

200200
backend = if backend_str == "usm"
201201
USMBackend()
202-
elseif backend_str == "bda"
203-
BDABackend()
204202
elseif backend_str == "svm"
205203
SVMBackend()
204+
elseif backend_str == "bda"
205+
BDABackend()
206206
else
207207
error("Unknown memory backend '$backend_str' requested")
208208
end

src/array.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function memory_type()
9999
elseif cl.memory_backend() == cl.SVMBackend()
100100
return cl.SharedVirtualMemory
101101
elseif cl.memory_backend() == cl.BDABackend()
102-
return cl.BufferDeviceMemory
102+
return cl.Buffer
103103
end
104104
end
105105
CLArray{T, N}(::UndefInitializer, dims::Dims{N}) where {T, N} =
@@ -175,11 +175,14 @@ context(A::CLArray) = cl.context(A.data[].mem)
175175
memtype(x::CLArray) = memtype(typeof(x))
176176
memtype(::Type{<:CLArray{<:Any, <:Any, M}}) where {M} = @isdefined(M) ? M : Any
177177

178-
is_device(a::CLArray) = memtype(a) == cl.UnifiedDeviceMemory
179-
is_shared(a::CLArray) = memtype(a) == cl.UnifiedSharedMemory
180-
is_host(a::CLArray) = memtype(a) == cl.UnifiedHostMemory
181-
is_svm(a::CLArray) = memtype(a) == cl.SharedVirtualMemory
182-
is_bda(a::CLArray) = memtype(a) == cl.BufferDeviceMemory
178+
# can the device access this array (i.e. can we derive a CLPtr from it)?
179+
is_device(a::CLArray) =
180+
memtype(a) in (cl.UnifiedDeviceMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory, cl.Buffer)
181+
is_shared(a::CLArray) =
182+
memtype(a) in (cl.UnifiedSharedMemory, cl.SharedVirtualMemory)
183+
is_host(a::CLArray) =
184+
memtype(a) in (cl.UnifiedHostMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory)
185+
183186

184187
## derived types
185188

@@ -283,13 +286,16 @@ end
283286
## interop with libraries
284287

285288
function Base.unsafe_convert(::Type{Ptr{T}}, x::CLArray{T}) where {T}
286-
if is_device(x)
289+
if !is_host(x)
287290
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
288291
end
289292
return convert(Ptr{T}, x.data[]) + x.offset * Base.elsize(x)
290293
end
291294

292295
function Base.unsafe_convert(::Type{CLPtr{T}}, x::CLArray{T}) where {T}
296+
if !is_device(x)
297+
throw(ArgumentError("cannot take the device address of a $(typeof(x))"))
298+
end
293299
return convert(CLPtr{T}, x.data[]) + x.offset * Base.elsize(x)
294300
end
295301

0 commit comments

Comments
 (0)