@@ -34,104 +34,9 @@ type ('a, 'b) t =
34
34
}
35
35
-> ('a , 'b ) t
36
36
37
- let is_device_available = function
38
- | Ocaml -> true
39
- | Metal -> Rune_metal. is_available
40
- | C -> (
41
- match Sys. backend_type with
42
- | Sys. (Native | Bytecode ) -> true
43
- | Sys. Other "js_of_ocaml" -> false
44
- | _ -> false )
45
-
46
- (* Ocaml_context creation *)
47
- let create_context ?(device = Ocaml ) () : context =
48
- match device with
49
- | Ocaml -> Ocaml_context (Nx_native. create_context () )
50
- | C ->
51
- if not (is_device_available C ) then
52
- failwith " C backend is not available on this platform"
53
- else C_context (Nx_c. create_context () )
54
- | Metal ->
55
- if not (is_device_available Metal ) then
56
- failwith " Metal backend is not available on this platform"
57
- else Metal_context (Rune_metal. create_context () )
58
-
59
- let default_device () : device_type =
60
- if is_device_available Metal then Metal
61
- else if is_device_available C then C
62
- else Ocaml
63
-
64
- let create_default_context () : context =
65
- create_context ~device: (default_device () ) ()
66
-
67
- (* Extract context from tensor *)
68
- let context : type a b. (a, b) t -> context = function
69
- | Ocaml_tensor cpu_t -> Ocaml_context (Nx_native. context cpu_t)
70
- | C_tensor c_t -> C_context (Nx_c. context c_t)
71
- | Metal_tensor metal_t -> Metal_context (Rune_metal. context metal_t)
72
- | Symbolic_tensor _ -> failwith " Symbolic tensors do not have a context"
73
-
74
- (* Device transfer operations *)
75
- let to_device (target_ctx : context ) (t : ('a, 'b) t ) : ('a, 'b) t =
76
- match (target_ctx, t) with
77
- (* Already on correct device *)
78
- | Ocaml_context _, Ocaml_tensor _
79
- | Metal_context _, Metal_tensor _
80
- | C_context _ , C_tensor _ ->
81
- t
82
- (* CPU to Metal *)
83
- | Metal_context metal_ctx , Ocaml_tensor cpu_t ->
84
- let data = Nx_native. data cpu_t in
85
- Metal_tensor (Rune_metal. op_const_array metal_ctx data)
86
- (* Metal to CPU *)
87
- | Ocaml_context ctx , Metal_tensor metal_t ->
88
- let data = Rune_metal. data metal_t in
89
- Ocaml_tensor (Nx_native. op_const_array ctx data)
90
- (* CPU to C *)
91
- | C_context c_ctx , Ocaml_tensor cpu_t ->
92
- let data = Nx_native. data cpu_t in
93
- C_tensor (Nx_c. op_const_array c_ctx data)
94
- (* C to CPU *)
95
- | Ocaml_context ctx , C_tensor c_t ->
96
- let data = Nx_c. data c_t in
97
- Ocaml_tensor (Nx_native. op_const_array ctx data)
98
- (* Metal to C *)
99
- | C_context c_ctx , Metal_tensor metal_t ->
100
- let data = Rune_metal. data metal_t in
101
- C_tensor (Nx_c. op_const_array c_ctx data)
102
- (* C to Metal *)
103
- | Metal_context metal_ctx , C_tensor c_t ->
104
- let data = Nx_c. data c_t in
105
- Metal_tensor (Rune_metal. op_const_array metal_ctx data)
106
- (* Symbolic tensors update their context *)
107
- | _ , Symbolic_tensor _ -> failwith " Cannot transfer symbolic tensor to device"
108
-
109
- (* Lenses *)
110
- let view : type a b. (a, b) t -> Lazy_view.t = function
111
- | Ocaml_tensor t -> Nx_native. view t
112
- | Metal_tensor t -> Rune_metal. view t
113
- | C_tensor t -> Nx_c. view t
114
- | Symbolic_tensor { shape; _ } ->
115
- Lazy_view. create (Symbolic_shape. of_ints shape)
116
-
117
- let dtype : type a b. (a, b) t -> (a, b) Dtype.t = function
118
- | Ocaml_tensor t -> Nx_native. dtype t
119
- | Metal_tensor t -> Rune_metal. dtype t
120
- | C_tensor t -> Nx_c. dtype t
121
- | Symbolic_tensor { dtype; _ } -> dtype
122
-
123
- let is_symbolic = function Symbolic_tensor _ -> true | _ -> false
124
-
125
- let data : type a b .
126
- (a , b ) t -> (a , b , Bigarray_ext. c_layout ) Bigarray_ext.Array1. t = function
127
- | Ocaml_tensor t -> Nx_native. data t
128
- | Metal_tensor t -> Rune_metal. data t
129
- | C_tensor t -> Nx_c. data t
130
- | Symbolic_tensor { id; _ } ->
131
- failwith (Printf. sprintf " Cannot extract data from symbolic tensor %d" id)
132
-
133
37
(* Effects - no context in most operations per new backend interface *)
134
38
type _ Effect.t + =
39
+ | E_view : ('a , 'b ) t -> Lazy_view .t Effect .t
135
40
| E_buffer : {
136
41
context : context ;
137
42
dtype : ('a , 'b ) Dtype .t ;
@@ -307,6 +212,106 @@ type _ Effect.t +=
307
212
s : int array option ;
308
213
}
309
214
-> (float , Dtype .float64_elt ) t Effect .t
215
+ | E_psum : { t_in : ('a , 'b ) t } -> ('a , 'b ) t Effect .t
216
+
217
+ let is_device_available = function
218
+ | Ocaml -> true
219
+ | Metal -> Rune_metal. is_available
220
+ | C -> (
221
+ match Sys. backend_type with
222
+ | Sys. (Native | Bytecode ) -> true
223
+ | Sys. Other "js_of_ocaml" -> false
224
+ | _ -> false )
225
+
226
+ (* Ocaml_context creation *)
227
+ let create_context ?(device = Ocaml ) () : context =
228
+ match device with
229
+ | Ocaml -> Ocaml_context (Nx_native. create_context () )
230
+ | C ->
231
+ if not (is_device_available C ) then
232
+ failwith " C backend is not available on this platform"
233
+ else C_context (Nx_c. create_context () )
234
+ | Metal ->
235
+ if not (is_device_available Metal ) then
236
+ failwith " Metal backend is not available on this platform"
237
+ else Metal_context (Rune_metal. create_context () )
238
+
239
+ let default_device () : device_type =
240
+ if is_device_available Metal then Metal
241
+ else if is_device_available C then C
242
+ else Ocaml
243
+
244
+ let create_default_context () : context =
245
+ create_context ~device: (default_device () ) ()
246
+
247
+ (* Extract context from tensor *)
248
+ let context : type a b. (a, b) t -> context = function
249
+ | Ocaml_tensor cpu_t -> Ocaml_context (Nx_native. context cpu_t)
250
+ | C_tensor c_t -> C_context (Nx_c. context c_t)
251
+ | Metal_tensor metal_t -> Metal_context (Rune_metal. context metal_t)
252
+ | Symbolic_tensor _ -> failwith " Symbolic tensors do not have a context"
253
+
254
+ (* Device transfer operations *)
255
+ let to_device (target_ctx : context ) (t : ('a, 'b) t ) : ('a, 'b) t =
256
+ match (target_ctx, t) with
257
+ (* Already on correct device *)
258
+ | Ocaml_context _, Ocaml_tensor _
259
+ | Metal_context _, Metal_tensor _
260
+ | C_context _ , C_tensor _ ->
261
+ t
262
+ (* CPU to Metal *)
263
+ | Metal_context metal_ctx , Ocaml_tensor cpu_t ->
264
+ let data = Nx_native. data cpu_t in
265
+ Metal_tensor (Rune_metal. op_const_array metal_ctx data)
266
+ (* Metal to CPU *)
267
+ | Ocaml_context ctx , Metal_tensor metal_t ->
268
+ let data = Rune_metal. data metal_t in
269
+ Ocaml_tensor (Nx_native. op_const_array ctx data)
270
+ (* CPU to C *)
271
+ | C_context c_ctx , Ocaml_tensor cpu_t ->
272
+ let data = Nx_native. data cpu_t in
273
+ C_tensor (Nx_c. op_const_array c_ctx data)
274
+ (* C to CPU *)
275
+ | Ocaml_context ctx , C_tensor c_t ->
276
+ let data = Nx_c. data c_t in
277
+ Ocaml_tensor (Nx_native. op_const_array ctx data)
278
+ (* Metal to C *)
279
+ | C_context c_ctx , Metal_tensor metal_t ->
280
+ let data = Rune_metal. data metal_t in
281
+ C_tensor (Nx_c. op_const_array c_ctx data)
282
+ (* C to Metal *)
283
+ | Metal_context metal_ctx , C_tensor c_t ->
284
+ let data = Nx_c. data c_t in
285
+ Metal_tensor (Rune_metal. op_const_array metal_ctx data)
286
+ (* Symbolic tensors update their context *)
287
+ | _ , Symbolic_tensor _ -> failwith " Cannot transfer symbolic tensor to device"
288
+
289
+ (* Lenses *)
290
+ let view (type a b ) (x : (a, b) t ) : Lazy_view.t =
291
+ try Effect. perform (E_view x)
292
+ with Effect. Unhandled _ -> (
293
+ match x with
294
+ | Ocaml_tensor t -> Nx_native. view t
295
+ | Metal_tensor t -> Rune_metal. view t
296
+ | C_tensor t -> Nx_c. view t
297
+ | Symbolic_tensor { shape; _ } ->
298
+ Lazy_view. create (Symbolic_shape. of_ints shape))
299
+
300
+ let dtype : type a b. (a, b) t -> (a, b) Dtype.t = function
301
+ | Ocaml_tensor t -> Nx_native. dtype t
302
+ | Metal_tensor t -> Rune_metal. dtype t
303
+ | C_tensor t -> Nx_c. dtype t
304
+ | Symbolic_tensor { dtype; _ } -> dtype
305
+
306
+ let is_symbolic = function Symbolic_tensor _ -> true | _ -> false
307
+
308
+ let data : type a b .
309
+ (a , b ) t -> (a , b , Bigarray_ext. c_layout ) Bigarray_ext.Array1. t = function
310
+ | Ocaml_tensor t -> Nx_native. data t
311
+ | Metal_tensor t -> Rune_metal. data t
312
+ | C_tensor t -> Nx_c. data t
313
+ | Symbolic_tensor { id; _ } ->
314
+ failwith (Printf. sprintf " Cannot extract data from symbolic tensor %d" id)
310
315
311
316
(* Helper functions for different operation types *)
312
317
@@ -512,6 +517,12 @@ let op_recip t_in =
512
517
(fun () -> E_recip { t_in })
513
518
Nx_native. op_recip Rune_metal. op_recip Nx_c. op_recip t_in
514
519
520
+ (* Collective primitive: parallel sum across mapped axis, to be handled by
521
+ vmap. *)
522
+ let op_psum t_in =
523
+ try Effect. perform (E_psum { t_in })
524
+ with Effect. Unhandled _ -> failwith " psum must be used under vmap"
525
+
515
526
(* Reduction operations *)
516
527
let op_reduce_sum ~axes ~keepdims t_in =
517
528
reduce_op
0 commit comments