Skip to content

Commit f880b21

Browse files
committed
feat(rune): Work-in-progress vmap implementation
1 parent 8ab77cd commit f880b21

File tree

4 files changed

+1878
-983
lines changed

4 files changed

+1878
-983
lines changed

rune/lib/nx_rune.ml

Lines changed: 107 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -34,104 +34,9 @@ type ('a, 'b) t =
3434
}
3535
-> ('a, 'b) t
3636

37-
let is_device_available = function
38-
| Ocaml -> true
39-
| Metal -> Rune_metal.is_available
40-
| C -> (
41-
match Sys.backend_type with
42-
| Sys.(Native | Bytecode) -> true
43-
| Sys.Other "js_of_ocaml" -> false
44-
| _ -> false)
45-
46-
(* Ocaml_context creation *)
47-
let create_context ?(device = Ocaml) () : context =
48-
match device with
49-
| Ocaml -> Ocaml_context (Nx_native.create_context ())
50-
| C ->
51-
if not (is_device_available C) then
52-
failwith "C backend is not available on this platform"
53-
else C_context (Nx_c.create_context ())
54-
| Metal ->
55-
if not (is_device_available Metal) then
56-
failwith "Metal backend is not available on this platform"
57-
else Metal_context (Rune_metal.create_context ())
58-
59-
let default_device () : device_type =
60-
if is_device_available Metal then Metal
61-
else if is_device_available C then C
62-
else Ocaml
63-
64-
let create_default_context () : context =
65-
create_context ~device:(default_device ()) ()
66-
67-
(* Extract context from tensor *)
68-
let context : type a b. (a, b) t -> context = function
69-
| Ocaml_tensor cpu_t -> Ocaml_context (Nx_native.context cpu_t)
70-
| C_tensor c_t -> C_context (Nx_c.context c_t)
71-
| Metal_tensor metal_t -> Metal_context (Rune_metal.context metal_t)
72-
| Symbolic_tensor _ -> failwith "Symbolic tensors do not have a context"
73-
74-
(* Device transfer operations *)
75-
let to_device (target_ctx : context) (t : ('a, 'b) t) : ('a, 'b) t =
76-
match (target_ctx, t) with
77-
(* Already on correct device *)
78-
| Ocaml_context _, Ocaml_tensor _
79-
| Metal_context _, Metal_tensor _
80-
| C_context _, C_tensor _ ->
81-
t
82-
(* CPU to Metal *)
83-
| Metal_context metal_ctx, Ocaml_tensor cpu_t ->
84-
let data = Nx_native.data cpu_t in
85-
Metal_tensor (Rune_metal.op_const_array metal_ctx data)
86-
(* Metal to CPU *)
87-
| Ocaml_context ctx, Metal_tensor metal_t ->
88-
let data = Rune_metal.data metal_t in
89-
Ocaml_tensor (Nx_native.op_const_array ctx data)
90-
(* CPU to C *)
91-
| C_context c_ctx, Ocaml_tensor cpu_t ->
92-
let data = Nx_native.data cpu_t in
93-
C_tensor (Nx_c.op_const_array c_ctx data)
94-
(* C to CPU *)
95-
| Ocaml_context ctx, C_tensor c_t ->
96-
let data = Nx_c.data c_t in
97-
Ocaml_tensor (Nx_native.op_const_array ctx data)
98-
(* Metal to C *)
99-
| C_context c_ctx, Metal_tensor metal_t ->
100-
let data = Rune_metal.data metal_t in
101-
C_tensor (Nx_c.op_const_array c_ctx data)
102-
(* C to Metal *)
103-
| Metal_context metal_ctx, C_tensor c_t ->
104-
let data = Nx_c.data c_t in
105-
Metal_tensor (Rune_metal.op_const_array metal_ctx data)
106-
(* Symbolic tensors update their context *)
107-
| _, Symbolic_tensor _ -> failwith "Cannot transfer symbolic tensor to device"
108-
109-
(* Lenses *)
110-
let view : type a b. (a, b) t -> Lazy_view.t = function
111-
| Ocaml_tensor t -> Nx_native.view t
112-
| Metal_tensor t -> Rune_metal.view t
113-
| C_tensor t -> Nx_c.view t
114-
| Symbolic_tensor { shape; _ } ->
115-
Lazy_view.create (Symbolic_shape.of_ints shape)
116-
117-
let dtype : type a b. (a, b) t -> (a, b) Dtype.t = function
118-
| Ocaml_tensor t -> Nx_native.dtype t
119-
| Metal_tensor t -> Rune_metal.dtype t
120-
| C_tensor t -> Nx_c.dtype t
121-
| Symbolic_tensor { dtype; _ } -> dtype
122-
123-
let is_symbolic = function Symbolic_tensor _ -> true | _ -> false
124-
125-
let data : type a b.
126-
(a, b) t -> (a, b, Bigarray_ext.c_layout) Bigarray_ext.Array1.t = function
127-
| Ocaml_tensor t -> Nx_native.data t
128-
| Metal_tensor t -> Rune_metal.data t
129-
| C_tensor t -> Nx_c.data t
130-
| Symbolic_tensor { id; _ } ->
131-
failwith (Printf.sprintf "Cannot extract data from symbolic tensor %d" id)
132-
13337
(* Effects - no context in most operations per new backend interface *)
13438
type _ Effect.t +=
39+
| E_view : ('a, 'b) t -> Lazy_view.t Effect.t
13540
| E_buffer : {
13641
context : context;
13742
dtype : ('a, 'b) Dtype.t;
@@ -307,6 +212,106 @@ type _ Effect.t +=
307212
s : int array option;
308213
}
309214
-> (float, Dtype.float64_elt) t Effect.t
215+
| E_psum : { t_in : ('a, 'b) t } -> ('a, 'b) t Effect.t
216+
217+
let is_device_available = function
218+
| Ocaml -> true
219+
| Metal -> Rune_metal.is_available
220+
| C -> (
221+
match Sys.backend_type with
222+
| Sys.(Native | Bytecode) -> true
223+
| Sys.Other "js_of_ocaml" -> false
224+
| _ -> false)
225+
226+
(* Ocaml_context creation *)
227+
let create_context ?(device = Ocaml) () : context =
228+
match device with
229+
| Ocaml -> Ocaml_context (Nx_native.create_context ())
230+
| C ->
231+
if not (is_device_available C) then
232+
failwith "C backend is not available on this platform"
233+
else C_context (Nx_c.create_context ())
234+
| Metal ->
235+
if not (is_device_available Metal) then
236+
failwith "Metal backend is not available on this platform"
237+
else Metal_context (Rune_metal.create_context ())
238+
239+
let default_device () : device_type =
240+
if is_device_available Metal then Metal
241+
else if is_device_available C then C
242+
else Ocaml
243+
244+
let create_default_context () : context =
245+
create_context ~device:(default_device ()) ()
246+
247+
(* Extract context from tensor *)
248+
let context : type a b. (a, b) t -> context = function
249+
| Ocaml_tensor cpu_t -> Ocaml_context (Nx_native.context cpu_t)
250+
| C_tensor c_t -> C_context (Nx_c.context c_t)
251+
| Metal_tensor metal_t -> Metal_context (Rune_metal.context metal_t)
252+
| Symbolic_tensor _ -> failwith "Symbolic tensors do not have a context"
253+
254+
(* Device transfer operations *)
255+
let to_device (target_ctx : context) (t : ('a, 'b) t) : ('a, 'b) t =
256+
match (target_ctx, t) with
257+
(* Already on correct device *)
258+
| Ocaml_context _, Ocaml_tensor _
259+
| Metal_context _, Metal_tensor _
260+
| C_context _, C_tensor _ ->
261+
t
262+
(* CPU to Metal *)
263+
| Metal_context metal_ctx, Ocaml_tensor cpu_t ->
264+
let data = Nx_native.data cpu_t in
265+
Metal_tensor (Rune_metal.op_const_array metal_ctx data)
266+
(* Metal to CPU *)
267+
| Ocaml_context ctx, Metal_tensor metal_t ->
268+
let data = Rune_metal.data metal_t in
269+
Ocaml_tensor (Nx_native.op_const_array ctx data)
270+
(* CPU to C *)
271+
| C_context c_ctx, Ocaml_tensor cpu_t ->
272+
let data = Nx_native.data cpu_t in
273+
C_tensor (Nx_c.op_const_array c_ctx data)
274+
(* C to CPU *)
275+
| Ocaml_context ctx, C_tensor c_t ->
276+
let data = Nx_c.data c_t in
277+
Ocaml_tensor (Nx_native.op_const_array ctx data)
278+
(* Metal to C *)
279+
| C_context c_ctx, Metal_tensor metal_t ->
280+
let data = Rune_metal.data metal_t in
281+
C_tensor (Nx_c.op_const_array c_ctx data)
282+
(* C to Metal *)
283+
| Metal_context metal_ctx, C_tensor c_t ->
284+
let data = Nx_c.data c_t in
285+
Metal_tensor (Rune_metal.op_const_array metal_ctx data)
286+
(* Symbolic tensors update their context *)
287+
| _, Symbolic_tensor _ -> failwith "Cannot transfer symbolic tensor to device"
288+
289+
(* Lenses *)
290+
let view (type a b) (x : (a, b) t) : Lazy_view.t =
291+
try Effect.perform (E_view x)
292+
with Effect.Unhandled _ -> (
293+
match x with
294+
| Ocaml_tensor t -> Nx_native.view t
295+
| Metal_tensor t -> Rune_metal.view t
296+
| C_tensor t -> Nx_c.view t
297+
| Symbolic_tensor { shape; _ } ->
298+
Lazy_view.create (Symbolic_shape.of_ints shape))
299+
300+
let dtype : type a b. (a, b) t -> (a, b) Dtype.t = function
301+
| Ocaml_tensor t -> Nx_native.dtype t
302+
| Metal_tensor t -> Rune_metal.dtype t
303+
| C_tensor t -> Nx_c.dtype t
304+
| Symbolic_tensor { dtype; _ } -> dtype
305+
306+
let is_symbolic = function Symbolic_tensor _ -> true | _ -> false
307+
308+
let data : type a b.
309+
(a, b) t -> (a, b, Bigarray_ext.c_layout) Bigarray_ext.Array1.t = function
310+
| Ocaml_tensor t -> Nx_native.data t
311+
| Metal_tensor t -> Rune_metal.data t
312+
| C_tensor t -> Nx_c.data t
313+
| Symbolic_tensor { id; _ } ->
314+
failwith (Printf.sprintf "Cannot extract data from symbolic tensor %d" id)
310315

311316
(* Helper functions for different operation types *)
312317

@@ -512,6 +517,12 @@ let op_recip t_in =
512517
(fun () -> E_recip { t_in })
513518
Nx_native.op_recip Rune_metal.op_recip Nx_c.op_recip t_in
514519

520+
(* Collective primitive: parallel sum across mapped axis, to be handled by
521+
vmap. *)
522+
let op_psum t_in =
523+
try Effect.perform (E_psum { t_in })
524+
with Effect.Unhandled _ -> failwith "psum must be used under vmap"
525+
515526
(* Reduction operations *)
516527
let op_reduce_sum ~axes ~keepdims t_in =
517528
reduce_op

0 commit comments

Comments
 (0)