diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 284919d57..34cb95033 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -15,16 +15,6 @@ jobs: os: - ubuntu-latest ocaml-compiler: - - "4.08" - - "4.09" - - "4.10" - - "4.11" - - "4.12" - - "4.13" - - "4.14" - - "5.0" - - "5.1" - - "5.2" - "5.3" libev: - true diff --git a/CHANGES b/CHANGES index 76a6cdb80..4f9d5b9cc 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,9 @@ +===== 6.0.0 ===== + + * Support multiple scheduler running in parallel in separate domains. + + * Exception filter defaults to letting systems exceptions through. + ===== 5.9.0 ===== ====== Additions ====== diff --git a/dune-project b/dune-project index 5f0af6dfa..033af129f 100644 --- a/dune-project +++ b/dune-project @@ -57,7 +57,8 @@ a single thread by default. This reduces the need for locks or other synchronization primitives. Code can be run in parallel on an opt-in basis. ") (depends - (ocaml (>= 4.08)) + (ocaml (>= 4.14)) + domain_shims (cppo (and :build (>= 1.1.0))) (ocamlfind (and :dev (>= 1.7.3-1))) (odoc (and :with-doc (>= 2.3.0))) diff --git a/lwt.opam b/lwt.opam index 6019f80d1..264c6d7d3 100644 --- a/lwt.opam +++ b/lwt.opam @@ -21,7 +21,8 @@ doc: "https://ocsigen.org/lwt" bug-reports: "https://github.com/ocsigen/lwt/issues" depends: [ "dune" {>= "2.7"} - "ocaml" {>= "4.08"} + "ocaml" {>= "4.14"} + "domain_shims" "cppo" {build & >= "1.1.0"} "ocamlfind" {dev & >= "1.7.3-1"} "odoc" {with-doc & >= "2.3.0"} diff --git a/src/core/domain_map.ml b/src/core/domain_map.ml new file mode 100644 index 000000000..f80232f19 --- /dev/null +++ b/src/core/domain_map.ml @@ -0,0 +1,57 @@ +module Domain_map : Map.S with type key = Domain.id = Map.Make(struct + type t = Domain.id + let compare d1 d2 = Int.compare (d1 : Domain.id :> int) (d2 : Domain.id :> int) +end) + +(* Protected domain map reference with per-reference mutex *) +type 'a protected_map = { + mutex : Mutex.t; + mutable map : 'a Domain_map.t; +} + +let create_protected_map () = { + mutex = Mutex.create (); + map = Domain_map.empty; +} + +let with_lock protected_map f = + Mutex.lock protected_map.mutex; + Fun.protect f ~finally:(fun () -> Mutex.unlock protected_map.mutex) + +let update_map protected_map f = + with_lock protected_map (fun () -> + let old_map = protected_map.map in + let new_map = f old_map in + protected_map.map <- new_map) + +let add protected_map key value = + update_map protected_map (Domain_map.add key value) + +let remove protected_map key = + update_map protected_map (Domain_map.remove key) + +let update protected_map key f = + update_map protected_map (Domain_map.update key f) + +let find protected_map key = + with_lock protected_map (fun () -> Domain_map.find_opt key protected_map.map) + +let extract protected_map key = + with_lock protected_map (fun () -> + match Domain_map.find_opt key protected_map.map with + | None -> None + | Some v -> + protected_map.map <- Domain_map.remove key protected_map.map; + Some v) + +let size protected_map = + with_lock protected_map (fun () -> Domain_map.cardinal protected_map.map) + +let init protected_map key init_value = + with_lock protected_map (fun () -> + match Domain_map.find_opt key protected_map.map with + | Some existing -> existing + | None -> + let new_value = init_value () in + protected_map.map <- Domain_map.add key new_value protected_map.map; + new_value) diff --git a/src/core/domain_map.mli b/src/core/domain_map.mli new file mode 100644 index 000000000..008bc81be --- /dev/null +++ b/src/core/domain_map.mli @@ -0,0 +1,38 @@ +(** Domain-indexed maps with thread-safe operations + + Only intended to use internally, not for general release. + + Note that these function use a lock. A single lock. + - Probably not optimal + - Deadlock if you call one of those functions inside another (e.g., use + `init` rather than `find`+`update` + *) + +(** Thread-safe wrapper for domain maps *) +type 'a protected_map + +(** Create a new protected map with an empty map inside and a dedicated mutex, + the map is keyed on domain ids, and operations are synchronised via a mutex. + *) +val create_protected_map : unit -> 'a protected_map + +(** Add a key-value binding to the map *) +val add : 'a protected_map -> Domain.id -> 'a -> unit + +(** Remove a key from the map *) +val remove : 'a protected_map -> Domain.id -> unit + +(** Update a binding using the underlying map's update function *) +val update : 'a protected_map -> Domain.id -> ('a option -> 'a option) -> unit + +(** Find a value by key, returning None if not found *) +val find : 'a protected_map -> Domain.id -> 'a option + +(** Find + remove but hit the mutex only once *) +val extract : 'a protected_map -> Domain.id -> 'a option + +(** Get the number of bindings in the map *) +val size : 'a protected_map -> int + +(** Initialize a key with a value if it doesn't exist, return existing or new value *) +val init : 'a protected_map -> Domain.id -> (unit -> 'a) -> 'a diff --git a/src/core/dune b/src/core/dune index dab7ccc8d..cdc69e89a 100644 --- a/src/core/dune +++ b/src/core/dune @@ -2,6 +2,7 @@ (public_name lwt) (synopsis "Monadic promises and concurrent I/O") (wrapped false) + (libraries domain_shims) (instrumentation (backend bisect_ppx))) diff --git a/src/core/lwt.ml b/src/core/lwt.ml index 257134c63..06b6ac989 100644 --- a/src/core/lwt.ml +++ b/src/core/lwt.ml @@ -365,6 +365,60 @@ module Storage_map = type storage = (unit -> unit) Storage_map.t +(* callback_exchange is a domain-indexed map for storing callbacks that + different domains should execute. This is used when a domain d1 resolves a + promise on which a different domain d2 has attached callbacks (implicitely + via bind etc. or explicitly via on_success etc.). When this happens, the + domain resolving the promise calls its local callbacks and sends the other + domains' callbacks into the callback exchange *) +let callback_exchange = Domain_map.create_protected_map () + +(* notification_map is a domain-indexed map for waking sleeping domains. each + (should) domain registers a notification (see Lwt_unix) into the map when it + starts its scheduler. other domains can wake the domain up to indicate that + callbacks are available to be called *) +let notification_map = Domain_map.create_protected_map () + +(* send_callback d cb adds the callback cb into the callback_exchange and pings + the domain d via the notification_map *) +let send_callback d cb = + Domain_map.update + callback_exchange + d + (function + | None -> + let cbs = Lwt_sequence.create () in + let _ : (unit -> unit) Lwt_sequence.node = Lwt_sequence.add_l cb cbs in + Some cbs + | Some cbs -> + let _ : (unit -> unit) Lwt_sequence.node = Lwt_sequence.add_l cb cbs in + Some cbs); + begin match Domain_map.find notification_map d with + | None -> + failwith "ERROR: domain didn't register at startup" + | Some n -> + n () + end + +(* get_sent_callbacks gets a domain's own callback from the callbasck exchange, + this is so that the notification handler installed by main.run can obtain the + callbacks that have been sent its way *) +let get_sent_callbacks domain_id = + match Domain_map.extract callback_exchange domain_id with + | None -> Lwt_sequence.create () + | Some cbs -> cbs + +(* register_notification adds a domain's own notification (see Lwt_unix) into + the notification map *) +let register_notification d n = + Domain_map.update notification_map d (function + | None -> Some n + | Some _ -> failwith "already registered!!") + +let is_alredy_registered d = + match Domain_map.find notification_map d with + | Some _ -> true + | None -> false module Main_internal_types = struct @@ -452,9 +506,9 @@ struct | Regular_callback_list_concat of 'a regular_callback_list * 'a regular_callback_list | Regular_callback_list_implicitly_removed_callback of - 'a regular_callback + Domain.id * 'a regular_callback | Regular_callback_list_explicitly_removable_callback of - 'a regular_callback option ref + Domain.id * 'a regular_callback option ref and _ cancel_callback_list = | Cancel_callback_list_empty : @@ -463,10 +517,10 @@ struct 'a cancel_callback_list * 'a cancel_callback_list -> 'a cancel_callback_list | Cancel_callback_list_callback : - storage * cancel_callback -> + Domain.id * storage * cancel_callback -> _ cancel_callback_list | Cancel_callback_list_remove_sequence_node : - ('a, _, _) promise Lwt_sequence.node -> + Domain.id * ('a, _, _) promise Lwt_sequence.node -> 'a cancel_callback_list (* Notes: @@ -716,11 +770,9 @@ module Exception_filter = struct | Out_of_memory -> false | Stack_overflow -> false | _ -> true - let v = - (* Default value: the legacy behaviour to avoid breaking programs *) - ref handle_all - let set f = v := f - let run e = !v e + let v = Atomic.make handle_all_except_runtime + let set f = Atomic.set v f + let run e = (Atomic.get v) e end module Sequence_associated_storage : @@ -732,7 +784,7 @@ sig val with_value : 'v key -> 'v option -> (unit -> 'b) -> 'b (* Internal interface *) - val current_storage : storage ref + val current_storage : storage Domain.DLS.key end = struct (* The idea behind sequence-associated storage is to preserve some values @@ -766,18 +818,17 @@ struct mutable value : 'v option; } - let next_key_id = ref 0 + let next_key_id = Atomic.make 0 let new_key () = - let id = !next_key_id in - next_key_id := id + 1; + let id = Atomic.fetch_and_add next_key_id 1 in {id = id; value = None} - let current_storage = ref Storage_map.empty + let current_storage = Domain.DLS.new_key (fun () -> Storage_map.empty) let get key = - if Storage_map.mem key.id !current_storage then begin - let refresh = Storage_map.find key.id !current_storage in + if Storage_map.mem key.id (Domain.DLS.get current_storage) then begin + let refresh = Storage_map.find key.id (Domain.DLS.get current_storage) in refresh (); let value = key.value in key.value <- None; @@ -791,19 +842,19 @@ struct match value with | Some _ -> let refresh = fun () -> key.value <- value in - Storage_map.add key.id refresh !current_storage + Storage_map.add key.id refresh (Domain.DLS.get current_storage) | None -> - Storage_map.remove key.id !current_storage + Storage_map.remove key.id (Domain.DLS.get current_storage) in - let saved_storage = !current_storage in - current_storage := new_storage; + let saved_storage = (Domain.DLS.get current_storage) in + Domain.DLS.set current_storage new_storage; try let result = f () in - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; result with exn when Exception_filter.run exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; raise exn end include Sequence_associated_storage @@ -840,10 +891,10 @@ struct (* In a callback list, filters out cells of explicitly removable callbacks that have been removed. *) let rec clean_up_callback_cells = function - | Regular_callback_list_explicitly_removable_callback {contents = None} -> + | Regular_callback_list_explicitly_removable_callback (_, {contents = None}) -> Regular_callback_list_empty - | Regular_callback_list_explicitly_removable_callback {contents = Some _} + | Regular_callback_list_explicitly_removable_callback (_, {contents = Some _}) | Regular_callback_list_implicitly_removed_callback _ | Regular_callback_list_empty as callbacks -> callbacks @@ -954,7 +1005,7 @@ struct let add_implicitly_removed_callback callbacks f = add_regular_callback_list_node - callbacks (Regular_callback_list_implicitly_removed_callback f) + callbacks (Regular_callback_list_implicitly_removed_callback (Domain.self (), f)) (* Adds [callback] as removable to each promise in [ps]. The first promise in [ps] to trigger [callback] removes [callback] from the other promises; this @@ -970,7 +1021,7 @@ struct f result in - let node = Regular_callback_list_explicitly_removable_callback cell in + let node = Regular_callback_list_explicitly_removable_callback (Domain.self (), cell) in ps |> List.iter (fun p -> let Internal p = to_internal_promise p in match (underlying p).state with @@ -991,7 +1042,7 @@ struct clear_explicitly_removable_callback_cell cell ~originally_added_to:ps let add_cancel_callback callbacks f = - let node = Cancel_callback_list_callback (!current_storage, f) in + let node = Cancel_callback_list_callback (Domain.self (), (Domain.DLS.get current_storage), f) in callbacks.cancel_callbacks <- match callbacks.cancel_callbacks with @@ -1166,12 +1217,23 @@ struct match fs with | Cancel_callback_list_empty -> iter_list rest - | Cancel_callback_list_callback (storage, f) -> - current_storage := storage; - handle_with_async_exception_hook f (); - iter_list rest - | Cancel_callback_list_remove_sequence_node node -> - Lwt_sequence.remove node; + | Cancel_callback_list_callback (domain, storage, f) -> + begin if domain = Domain.self () then begin + Domain.DLS.set current_storage storage; + handle_with_async_exception_hook f () + end else + send_callback domain (fun () -> + Domain.DLS.set current_storage storage; + handle_with_async_exception_hook f () + ) + end; + iter_list rest + | Cancel_callback_list_remove_sequence_node (domain, node) -> + begin if domain = Domain.self () then + Lwt_sequence.remove node + else + send_callback domain (fun () -> Lwt_sequence.remove node) + end; iter_list rest | Cancel_callback_list_concat (fs, fs') -> iter_callback_list fs (fs'::rest) @@ -1191,16 +1253,22 @@ struct match fs with | Regular_callback_list_empty -> iter_list rest - | Regular_callback_list_implicitly_removed_callback f -> - f result; - iter_list rest - | Regular_callback_list_explicitly_removable_callback - {contents = None} -> - iter_list rest - | Regular_callback_list_explicitly_removable_callback - {contents = Some f} -> - f result; - iter_list rest + | Regular_callback_list_implicitly_removed_callback (domain, f) -> + begin if domain = Domain.self () then + f result + else + send_callback domain (fun () -> f result) + end; + iter_list rest + | Regular_callback_list_explicitly_removable_callback (_, {contents = None}) -> + iter_list rest + | Regular_callback_list_explicitly_removable_callback (domain, {contents = Some f}) -> + begin if domain = Domain.self () then + f result + else + send_callback domain (fun () -> f result) + end; + iter_list rest | Regular_callback_list_concat (fs, fs') -> iter_callback_list fs (fs'::rest) @@ -1229,7 +1297,7 @@ struct let default_maximum_callback_nesting_depth = 42 - let current_callback_nesting_depth = ref 0 + let current_callback_nesting_depth = Domain.DLS.new_key (fun () -> 0) type deferred_callbacks = Deferred : ('a callbacks * 'a resolved_state) -> deferred_callbacks @@ -1242,19 +1310,19 @@ struct the callbacks that will be run will modify the storage. The storage is restored to the snapshot when the resolution loop is exited. *) let enter_resolution_loop () = - current_callback_nesting_depth := !current_callback_nesting_depth + 1; - let storage_snapshot = !current_storage in + Domain.DLS.set current_callback_nesting_depth (Domain.DLS.get current_callback_nesting_depth + 1); + let storage_snapshot = (Domain.DLS.get current_storage) in storage_snapshot let leave_resolution_loop (storage_snapshot : storage) : unit = - if !current_callback_nesting_depth = 1 then begin + if Domain.DLS.get current_callback_nesting_depth = 1 then begin while not (Queue.is_empty deferred_callbacks) do let Deferred (callbacks, result) = Queue.pop deferred_callbacks in run_callbacks callbacks result done end; - current_callback_nesting_depth := !current_callback_nesting_depth - 1; - current_storage := storage_snapshot + Domain.DLS.set current_callback_nesting_depth (Domain.DLS.get current_callback_nesting_depth - 1); + Domain.DLS.set current_storage storage_snapshot let run_in_resolution_loop f = let storage_snapshot = enter_resolution_loop () in @@ -1269,7 +1337,7 @@ struct The name should probably be [abaondon_resolution_loop]. *) let abandon_wakeups () = - if !current_callback_nesting_depth <> 0 then + if Domain.DLS.get current_callback_nesting_depth <> 0 then leave_resolution_loop Storage_map.empty @@ -1281,7 +1349,7 @@ struct let should_defer = allow_deferring - && !current_callback_nesting_depth >= maximum_callback_nesting_depth + && Domain.DLS.get current_callback_nesting_depth >= maximum_callback_nesting_depth in if should_defer then @@ -1309,7 +1377,7 @@ struct else let should_defer = - !current_callback_nesting_depth + Domain.DLS.get current_callback_nesting_depth >= default_maximum_callback_nesting_depth in @@ -1320,7 +1388,7 @@ struct { regular_callbacks = Regular_callback_list_implicitly_removed_callback - deferred_callback; + (Domain.self (), deferred_callback); cancel_callbacks = Cancel_callback_list_empty; how_to_cancel = Not_cancelable; cleanups_deferred = 0 @@ -1576,7 +1644,7 @@ struct let Pending callbacks = p.state in callbacks.cancel_callbacks <- - Cancel_callback_list_remove_sequence_node node; + Cancel_callback_list_remove_sequence_node (Domain.self (), node); to_public_promise p @@ -1587,7 +1655,7 @@ struct let Pending callbacks = p.state in callbacks.cancel_callbacks <- - Cancel_callback_list_remove_sequence_node node; + Cancel_callback_list_remove_sequence_node (Domain.self (), node); to_public_promise p @@ -1831,12 +1899,12 @@ struct [p''] will be equivalent to trying to cancel [p'], so the behavior will depend on how the user obtained [p']. *) - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f v with exn @@ -1897,12 +1965,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f v @@ -1954,12 +2022,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p''_result = try Fulfilled (f v) with exn @@ -2020,7 +2088,7 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with @@ -2033,7 +2101,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2081,7 +2149,7 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with @@ -2094,7 +2162,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2143,12 +2211,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f' v @@ -2164,7 +2232,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2218,12 +2286,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f' v @@ -2240,7 +2308,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2324,12 +2392,12 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f v | Rejected _ -> @@ -2357,7 +2425,7 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with @@ -2365,7 +2433,7 @@ struct () | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f exn in @@ -2390,10 +2458,10 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun _result -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f () in @@ -2423,16 +2491,16 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f v | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook g exn in @@ -3161,34 +3229,34 @@ struct - let pause_hook = ref ignore + let pause_hook = Domain.DLS.new_key (fun () -> ignore) - let paused = Lwt_sequence.create () - let paused_count = ref 0 + let paused = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) + let paused_count = Domain.DLS.new_key (fun () -> 0) let pause () = - let p = add_task_r paused in - incr paused_count; - !pause_hook !paused_count; + let p = add_task_r (Domain.DLS.get paused) in + Domain.DLS.set paused_count (Domain.DLS.get paused_count + 1); + (Domain.DLS.get pause_hook) (Domain.DLS.get paused_count); p let wakeup_paused () = - if Lwt_sequence.is_empty paused then - paused_count := 0 + if Lwt_sequence.is_empty (Domain.DLS.get paused) then + Domain.DLS.set paused_count 0 else begin let tmp = Lwt_sequence.create () in - Lwt_sequence.transfer_r paused tmp; - paused_count := 0; + Lwt_sequence.transfer_r (Domain.DLS.get paused) tmp; + Domain.DLS.set paused_count 0; Lwt_sequence.iter_l (fun r -> wakeup r ()) tmp end - let register_pause_notifier f = pause_hook := f + let register_pause_notifier f = Domain.DLS.set pause_hook f let abandon_paused () = - Lwt_sequence.clear paused; - paused_count := 0 + Lwt_sequence.clear (Domain.DLS.get paused); + Domain.DLS.set paused_count 0 - let paused_count () = !paused_count + let paused_count () = Domain.DLS.get paused_count end include Miscellaneous diff --git a/src/core/lwt.mli b/src/core/lwt.mli index 7598343d8..58c43d5df 100644 --- a/src/core/lwt.mli +++ b/src/core/lwt.mli @@ -2061,3 +2061,9 @@ val backtrace_try_bind : val abandon_wakeups : unit -> unit val debug_state_is : 'a state -> 'a t -> bool t + +[@@@ocaml.warning "-3"] +(* this is only for cross-domain scheduler synchronisation *) +val get_sent_callbacks : Domain.id -> (unit -> unit) Lwt_sequence.t +val register_notification : Domain.id -> (unit -> unit) -> unit +val is_alredy_registered : Domain.id -> bool diff --git a/src/unix/dune b/src/unix/dune index a5c6a3977..a548de2fb 100644 --- a/src/unix/dune +++ b/src/unix/dune @@ -191,6 +191,6 @@ (flags (:include unix_c_flags.sexp))) (c_library_flags - (:include unix_c_library_flags.sexp)) + (:include unix_c_library_flags.sexp) -fPIC -pthread) (instrumentation (backend bisect_ppx))) diff --git a/src/unix/lwt_engine.ml b/src/unix/lwt_engine.ml index 20a8eafc7..a2c6ba3cb 100644 --- a/src/unix/lwt_engine.ml +++ b/src/unix/lwt_engine.ml @@ -416,29 +416,30 @@ end +-----------------------------------------------------------------+ *) let current = - if Lwt_config._HAVE_LIBEV && Lwt_config.libev_default then - ref (new libev () :> t) - else - ref (new select :> t) + Domain.DLS.new_key (fun () -> + if Lwt_config._HAVE_LIBEV && Lwt_config.libev_default then + (new libev () :> t) + else + (new select :> t) +) -let get () = - !current +let get () = Domain.DLS.get current let set ?(transfer=true) ?(destroy=true) engine = - if transfer then !current#transfer (engine : #t :> abstract); - if destroy then !current#destroy; - current := (engine : #t :> t) - -let iter block = !current#iter block -let on_readable fd f = !current#on_readable fd f -let on_writable fd f = !current#on_writable fd f -let on_timer delay repeat f = !current#on_timer delay repeat f -let fake_io fd = !current#fake_io fd -let readable_count () = !current#readable_count -let writable_count () = !current#writable_count -let timer_count () = !current#timer_count -let fork () = !current#fork -let forwards_signal n = !current#forwards_signal n + if transfer then (Domain.DLS.get current)#transfer (engine : #t :> abstract); + if destroy then (Domain.DLS.get current)#destroy; + Domain.DLS.set current (engine : #t :> t) + +let iter block = (Domain.DLS.get current)#iter block +let on_readable fd f = (Domain.DLS.get current)#on_readable fd f +let on_writable fd f = (Domain.DLS.get current)#on_writable fd f +let on_timer delay repeat f = (Domain.DLS.get current)#on_timer delay repeat f +let fake_io fd = (Domain.DLS.get current)#fake_io fd +let readable_count () = (Domain.DLS.get current)#readable_count +let writable_count () = (Domain.DLS.get current)#writable_count +let timer_count () = (Domain.DLS.get current)#timer_count +let fork () = (Domain.DLS.get current)#fork +let forwards_signal n = (Domain.DLS.get current)#forwards_signal n module Versioned = struct diff --git a/src/unix/lwt_gc.ml b/src/unix/lwt_gc.ml index b0925f9dc..72ca8be0a 100644 --- a/src/unix/lwt_gc.ml +++ b/src/unix/lwt_gc.ml @@ -12,17 +12,18 @@ module Lwt_sequence = Lwt_sequence let ensure_termination t = if Lwt.state t = Lwt.Sleep then begin - let hook = - Lwt_sequence.add_l (fun _ -> t) Lwt_main.exit_hooks [@ocaml.warning "-3"] - in + let hook = Lwt_main.Exit_hooks.add_first (fun _ -> t) in (* Remove the hook when t has terminated *) ignore ( Lwt.finalize (fun () -> t) - (fun () -> Lwt_sequence.remove hook; Lwt.return_unit)) + (fun () -> Lwt_main.Exit_hooks.remove hook; Lwt.return_unit)) end let finaliser f = + (* In order for the domain id to be consistent, wherever the real finaliser is + called, we pass it in the continuation. *) + let domain_id = Domain.self () in (* In order not to create a reference to the value in the notification callback, we use an initially unset option cell which will be filled when the finaliser is called. *) @@ -30,6 +31,7 @@ let finaliser f = let id = Lwt_unix.make_notification ~once:true + domain_id (fun () -> match !opt with | None -> @@ -41,7 +43,7 @@ let finaliser f = (* The real finaliser: fill the cell and send a notification. *) (fun x -> opt := Some x; - Lwt_unix.send_notification id) + Lwt_unix.send_notification domain_id id) let finalise f x = Gc.finalise (finaliser f) x @@ -68,7 +70,7 @@ let foe_finaliser f called hook = finaliser (fun x -> (* Remove the exit hook, it is not needed anymore. *) - Lwt_sequence.remove hook; + Lwt_main.Exit_hooks.remove hook; (* Call the real finaliser. *) if !called then Lwt.return_unit @@ -83,8 +85,5 @@ let finalise_or_exit f x = let weak = Weak.create 1 in Weak.set weak 0 (Some x); let called = ref false in - let hook = - Lwt_sequence.add_l (foe_exit f called weak) Lwt_main.exit_hooks - [@ocaml.warning "-3"] - in + let hook = Lwt_main.Exit_hooks.add_first (foe_exit f called weak) in Gc.finalise (foe_finaliser f called hook) x diff --git a/src/unix/lwt_main.ml b/src/unix/lwt_main.ml index 823666e5f..e2ae48c3b 100644 --- a/src/unix/lwt_main.ml +++ b/src/unix/lwt_main.ml @@ -12,8 +12,8 @@ module Lwt_sequence = Lwt_sequence open Lwt.Infix -let enter_iter_hooks = Lwt_sequence.create () -let leave_iter_hooks = Lwt_sequence.create () +let enter_iter_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) +let leave_iter_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) let yield = Lwt.pause @@ -21,13 +21,24 @@ let abandon_yielded_and_paused () = Lwt.abandon_paused () let run p = + let domain_id = Domain.self () in + let () = if Lwt.is_alredy_registered domain_id then + () + else begin + let n = Lwt_unix.make_notification domain_id (fun () -> + let cbs = Lwt.get_sent_callbacks domain_id in + Lwt_sequence.iter_l (fun f -> f ()) cbs + ) in + Lwt.register_notification domain_id (fun () -> Lwt_unix.send_notification domain_id n) + end + in let rec run_loop () = match Lwt.poll p with | Some x -> x | None -> (* Call enter hooks. *) - Lwt_sequence.iter_l (fun f -> f ()) enter_iter_hooks; + Lwt_sequence.iter_l (fun f -> f ()) (Domain.DLS.get enter_iter_hooks); (* Do the main loop call. *) let should_block_waiting_for_io = Lwt.paused_count () = 0 in @@ -37,7 +48,7 @@ let run p = Lwt.wakeup_paused (); (* Call leave hooks. *) - Lwt_sequence.iter_l (fun f -> f ()) leave_iter_hooks; + Lwt_sequence.iter_l (fun f -> f ()) (Domain.DLS.get leave_iter_hooks); (* Repeat. *) run_loop () @@ -45,56 +56,54 @@ let run p = run_loop () -let run_already_called = ref `No -let run_already_called_mutex = Mutex.create () +let run_already_called = Domain.DLS.new_key (fun () -> `No) +let run_already_called_mutex = Domain.DLS.new_key (fun () -> Mutex.create ()) let finished () = - Mutex.lock run_already_called_mutex; - run_already_called := `No; - Mutex.unlock run_already_called_mutex + Mutex.lock (Domain.DLS.get run_already_called_mutex); + Domain.DLS.set run_already_called `No; + Mutex.unlock (Domain.DLS.get run_already_called_mutex) let run p = (* Fail in case a call to Lwt_main.run is nested under another invocation of Lwt_main.run. *) - Mutex.lock run_already_called_mutex; - + Mutex.lock (Domain.DLS.get run_already_called_mutex); let error_message_if_call_is_nested = - match !run_already_called with - (* `From is effectively disabled for the time being, because there is a bug, - present in all versions of OCaml supported by Lwt, where, with the - bytecode runtime, if one changes the working directory and then attempts - to retrieve the backtrace, the runtime calls [abort] at the C level and - exits the program ungracefully. It is especially likely that a daemon - would change directory before calling [Lwt_main.run], so we can't have it - retrieving the backtrace, even though a daemon is not likely to be - compiled to bytecode. - - This can be addressed with detection. Starting with 4.04, there is a - type [Sys.backend_type] that could be used. *) - | `From backtrace_string -> - Some (Printf.sprintf "%s\n%s\n%s" - "Nested calls to Lwt_main.run are not allowed" - "Lwt_main.run already called from:" - backtrace_string) - | `From_somewhere -> - Some ("Nested calls to Lwt_main.run are not allowed") - | `No -> - let called_from = - (* See comment above. - if Printexc.backtrace_status () then - let backtrace = - try raise Exit - with Exit -> Printexc.get_backtrace () - in - `From backtrace - else *) - `From_somewhere - in - run_already_called := called_from; - None + match (Domain.DLS.get run_already_called) with + (* `From is effectively disabled for the time being, because there is a bug, + present in all versions of OCaml supported by Lwt, where, with the + bytecode runtime, if one changes the working directory and then attempts + to retrieve the backtrace, the runtime calls [abort] at the C level and + exits the program ungracefully. It is especially likely that a daemon + would change directory before calling [Lwt_main.run], so we can't have it + retrieving the backtrace, even though a daemon is not likely to be + compiled to bytecode. + + This can be addressed with detection. Starting with 4.04, there is a + type [Sys.backend_type] that could be used. *) + | `From backtrace_string -> + Some (Printf.sprintf "%s\n%s\n%s" + "Nested calls to Lwt_main.run are not allowed" + "Lwt_main.run already called from:" + backtrace_string) + | `From_somewhere -> + Some ("Nested calls to Lwt_main.run are not allowed") + | `No -> + let called_from = + (* See comment above. + if Printexc.backtrace_status () then + let backtrace = + try raise Exit + with Exit -> Printexc.get_backtrace () + in + `From backtrace + else *) + `From_somewhere + in + Domain.DLS.set run_already_called called_from; + None in - - Mutex.unlock run_already_called_mutex; + Mutex.unlock (Domain.DLS.get run_already_called_mutex); begin match error_message_if_call_is_nested with | Some message -> failwith message @@ -109,10 +118,10 @@ let run p = finished (); raise exn -let exit_hooks = Lwt_sequence.create () +let exit_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) let rec call_hooks () = - match Lwt_sequence.take_opt_l exit_hooks with + match Lwt_sequence.take_opt_l (Domain.DLS.get exit_hooks) with | None -> Lwt.return_unit | Some f -> @@ -123,13 +132,13 @@ let rec call_hooks () = let () = at_exit (fun () -> - if not (Lwt_sequence.is_empty exit_hooks) then begin + if not (Lwt_sequence.is_empty (Domain.DLS.get exit_hooks)) then begin Lwt.abandon_wakeups (); finished (); run (call_hooks ()) end) -let at_exit f = ignore (Lwt_sequence.add_l f exit_hooks) +let at_exit f = ignore (Lwt_sequence.add_l f (Domain.DLS.get exit_hooks)) module type Hooks = sig @@ -145,7 +154,7 @@ end module type Hook_sequence = sig type 'return_value kind - val sequence : (unit -> unit kind) Lwt_sequence.t + val sequence : (unit -> unit kind) Lwt_sequence.t Domain.DLS.key end module Wrap_hooks (Sequence : Hook_sequence) = @@ -154,18 +163,18 @@ struct type hook = (unit -> unit Sequence.kind) Lwt_sequence.node let add_first hook_fn = - let hook_node = Lwt_sequence.add_l hook_fn Sequence.sequence in + let hook_node = Lwt_sequence.add_l hook_fn (Domain.DLS.get Sequence.sequence) in hook_node let add_last hook_fn = - let hook_node = Lwt_sequence.add_r hook_fn Sequence.sequence in + let hook_node = Lwt_sequence.add_r hook_fn (Domain.DLS.get Sequence.sequence) in hook_node let remove hook_node = Lwt_sequence.remove hook_node let remove_all () = - Lwt_sequence.iter_node_l Lwt_sequence.remove Sequence.sequence + Lwt_sequence.iter_node_l Lwt_sequence.remove (Domain.DLS.get Sequence.sequence) end module Enter_iter_hooks = diff --git a/src/unix/lwt_main.mli b/src/unix/lwt_main.mli index f2ebde219..60c843ba2 100644 --- a/src/unix/lwt_main.mli +++ b/src/unix/lwt_main.mli @@ -126,29 +126,6 @@ module Leave_iter_hooks : module Exit_hooks : Hooks with type 'return_value kind = 'return_value Lwt.t - - -[@@@ocaml.warning "-3"] - -val enter_iter_hooks : (unit -> unit) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Enter_iter_hooks."] -(** @deprecated Use module {!Enter_iter_hooks}. *) - -val leave_iter_hooks : (unit -> unit) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Leave_iter_hooks."] -(** @deprecated Use module {!Leave_iter_hooks}. *) - -val exit_hooks : (unit -> unit Lwt.t) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Exit_hooks."] -(** @deprecated Use module {!Exit_hooks}. *) - -[@@@ocaml.warning "+3"] - - - val at_exit : (unit -> unit Lwt.t) -> unit (** [Lwt_main.at_exit hook] is the same as [ignore (Lwt_main.Exit_hooks.add_first hook)]. *) diff --git a/src/unix/lwt_preemptive.ml b/src/unix/lwt_preemptive.ml index eacf32f28..691160e3c 100644 --- a/src/unix/lwt_preemptive.ml +++ b/src/unix/lwt_preemptive.ml @@ -17,23 +17,23 @@ open Lwt.Infix +-----------------------------------------------------------------+ *) (* Minimum number of preemptive threads: *) -let min_threads : int ref = ref 0 +let min_threads : int Atomic.t = Atomic.make 0 (* Maximum number of preemptive threads: *) -let max_threads : int ref = ref 0 +let max_threads : int Atomic.t = Atomic.make 0 (* Size of the waiting queue: *) -let max_thread_queued = ref 1000 +let max_thread_queued = Atomic.make 1000 let get_max_number_of_threads_queued _ = - !max_thread_queued + Atomic.get max_thread_queued let set_max_number_of_threads_queued n = if n < 0 then invalid_arg "Lwt_preemptive.set_max_number_of_threads_queued"; - max_thread_queued := n + Atomic.set max_thread_queued n (* The total number of preemptive threads currently running: *) -let threads_count = ref 0 +let threads_count = Atomic.make 0 (* +-----------------------------------------------------------------+ | Preemptive threads management | @@ -102,14 +102,14 @@ let rec worker_loop worker = task (); (* If there is too much threads, exit. This can happen if the user decreased the maximum: *) - if !threads_count > !max_threads then worker.reuse <- false; + if Atomic.get threads_count > Atomic.get max_threads then worker.reuse <- false; (* Tell the main thread that work is done: *) - Lwt_unix.send_notification id; + Lwt_unix.send_notification (Domain.self ()) id; if worker.reuse then worker_loop worker (* create a new worker: *) let make_worker () = - incr threads_count; + Atomic.incr threads_count; let worker = { task_cell = CELL.make (); thread = Thread.self (); @@ -130,7 +130,7 @@ let add_worker worker = let get_worker () = if not (Queue.is_empty workers) then Lwt.return (Queue.take workers) - else if !threads_count < !max_threads then + else if Atomic.get threads_count < Atomic.get max_threads then Lwt.return (make_worker ()) else (Lwt.add_task_r [@ocaml.warning "-3"]) waiters @@ -139,33 +139,33 @@ let get_worker () = | Initialisation, and dynamic parameters reset | +-----------------------------------------------------------------+ *) -let get_bounds () = (!min_threads, !max_threads) +let get_bounds () = (Atomic.get min_threads, Atomic.get max_threads) let set_bounds (min, max) = if min < 0 || max < min then invalid_arg "Lwt_preemptive.set_bounds"; - let diff = min - !threads_count in - min_threads := min; - max_threads := max; + let diff = min - Atomic.get threads_count in + Atomic.set min_threads min; + Atomic.set max_threads max; (* Launch new workers: *) for _i = 1 to diff do add_worker (make_worker ()) done -let initialized = ref false +let initialized = Atomic.make false let init min max _errlog = - initialized := true; + Atomic.set initialized true; set_bounds (min, max) let simple_init () = - if not !initialized then begin - initialized := true; + if not (Atomic.get initialized) then begin + Atomic.set initialized true; set_bounds (0, 4) end -let nbthreads () = !threads_count +let nbthreads () = Atomic.get threads_count let nbthreadsqueued () = Lwt_sequence.fold_l (fun _ x -> x + 1) waiters 0 -let nbthreadsbusy () = !threads_count - Queue.length workers +let nbthreadsbusy () = Atomic.get threads_count - Queue.length workers (* +-----------------------------------------------------------------+ | Detaching | @@ -186,7 +186,7 @@ let detach f args = get_worker () >>= fun worker -> let waiter, wakener = Lwt.wait () in let id = - Lwt_unix.make_notification ~once:true + Lwt_unix.make_notification ~once:true (Domain.self ()) (fun () -> Lwt.wakeup_result wakener !result) in Lwt.finalize @@ -199,7 +199,7 @@ let detach f args = (* Put back the worker to the pool: *) add_worker worker else begin - decr threads_count; + Atomic.decr threads_count; (* Or wait for the thread to terminates, to free its associated resources: *) Thread.join worker.thread @@ -217,7 +217,7 @@ let jobs = Queue.create () let jobs_mutex = Mutex.create () let job_notification = - Lwt_unix.make_notification + Lwt_unix.make_notification (Domain.self ()) (fun () -> (* Take the first job. The queue is never empty at this point. *) @@ -226,20 +226,20 @@ let job_notification = Mutex.unlock jobs_mutex; ignore (thunk ())) -let run_in_main_dont_wait f = +let run_in_domain_dont_wait d f = (* Add the job to the queue. *) Mutex.lock jobs_mutex; Queue.add f jobs; Mutex.unlock jobs_mutex; (* Notify the main thread. *) - Lwt_unix.send_notification job_notification + Lwt_unix.send_notification d job_notification (* There is a potential performance issue from creating a cell every time this function is called. See: https://github.com/ocsigen/lwt/issues/218 https://github.com/ocsigen/lwt/pull/219 https://github.com/ocaml/ocaml/issues/7158 *) -let run_in_main f = +let run_in_domain d f = let cell = CELL.make () in (* Create the job. *) let job () = @@ -251,13 +251,13 @@ let run_in_main f = CELL.set cell result; Lwt.return_unit in - run_in_main_dont_wait job; + run_in_domain_dont_wait d job; (* Wait for the result. *) match CELL.get cell with | Result.Ok ret -> ret | Result.Error exn -> raise exn (* This version shadows the one above, adding an exception handler *) -let run_in_main_dont_wait f handler = +let run_in_domain_dont_wait d f handler = let f () = Lwt.catch f (fun exc -> handler exc; Lwt.return_unit) in - run_in_main_dont_wait f + run_in_domain_dont_wait d f diff --git a/src/unix/lwt_preemptive.mli b/src/unix/lwt_preemptive.mli index 24077350f..df3cdda3c 100644 --- a/src/unix/lwt_preemptive.mli +++ b/src/unix/lwt_preemptive.mli @@ -21,21 +21,21 @@ val detach : ('a -> 'b) -> 'a -> 'b Lwt.t Note that Lwt thread-local storage (i.e., {!Lwt.with_value}) cannot be safely used from within [f]. The same goes for most of the rest of Lwt. If - you need to run an Lwt thread in [f], use {!run_in_main}. *) + you need to run an Lwt thread in [f], use {!run_in_domain}. *) -val run_in_main : (unit -> 'a Lwt.t) -> 'a - (** [run_in_main f] can be called from a detached computation to execute +val run_in_domain : Domain.id -> (unit -> 'a Lwt.t) -> 'a + (** [run_in_domain f] can be called from a detached computation to execute [f ()] in the main preemptive thread, i.e. the one executing - {!Lwt_main.run}. [run_in_main f] blocks until [f ()] completes, then - returns its result. If [f ()] raises an exception, [run_in_main f] raises + {!Lwt_main.run}. [run_in_domain f] blocks until [f ()] completes, then + returns its result. If [f ()] raises an exception, [run_in_domain f] raises the same exception. {!Lwt.with_value} may be used inside [f ()]. {!Lwt.get} can correctly retrieve values set this way inside [f ()], but not values set using {!Lwt.with_value} outside [f ()]. *) -val run_in_main_dont_wait : (unit -> unit Lwt.t) -> (exn -> unit) -> unit -(** [run_in_main_dont_wait f h] does the same as [run_in_main f] but a bit faster +val run_in_domain_dont_wait : Domain.id -> (unit -> unit Lwt.t) -> (exn -> unit) -> unit +(** [run_in_domain_dont_wait f h] does the same as [run_in_domain f] but a bit faster and lighter as it does not wait for the result of [f]. If [f]'s promise is rejected (or if it raises), then the function [h] is diff --git a/src/unix/lwt_unix.cppo.ml b/src/unix/lwt_unix.cppo.ml index 6fb9f8044..d757b9f08 100644 --- a/src/unix/lwt_unix.cppo.ml +++ b/src/unix/lwt_unix.cppo.ml @@ -21,17 +21,17 @@ type async_method = | Async_detach | Async_switch -let default_async_method_var = ref Async_detach +let default_async_method_var = Atomic.make Async_detach let () = try match Sys.getenv "LWT_ASYNC_METHOD" with | "none" -> - default_async_method_var := Async_none + Atomic.set default_async_method_var Async_none | "detach" -> - default_async_method_var := Async_detach + Atomic.set default_async_method_var Async_detach | "switch" -> - default_async_method_var := Async_switch + Atomic.set default_async_method_var Async_switch | str -> Printf.eprintf "%s: invalid lwt async method: '%s', must be 'none', 'detach' or 'switch'\n%!" @@ -39,15 +39,15 @@ let () = with Not_found -> () -let default_async_method () = !default_async_method_var -let set_default_async_method am = default_async_method_var := am +let default_async_method () = Atomic.get default_async_method_var +let set_default_async_method am = Atomic.set default_async_method_var am let async_method_key = Lwt.new_key () let async_method () = match Lwt.get async_method_key with | Some am -> am - | None -> !default_async_method_var + | None -> Atomic.get default_async_method_var let with_async_none f = Lwt.with_value async_method_key (Some Async_none) f @@ -78,38 +78,52 @@ module Notifiers = Hashtbl.Make(struct let hash (x : int) = x end) -let notifiers = Notifiers.create 1024 +let notifiers = Domain_map.create_protected_map () (* See https://github.com/ocsigen/lwt/issues/277 and https://github.com/ocsigen/lwt/pull/278. *) -let current_notification_id = ref (0x7FFFFFFF - 1000) +let current_notification_id = Atomic.make (0x7FFFFFFF - 1000) -let rec find_free_id id = - if Notifiers.mem notifiers id then - find_free_id (id + 1) - else - id - -let make_notification ?(once=false) f = - let id = find_free_id (!current_notification_id + 1) in - current_notification_id := id; - Notifiers.add notifiers id { notify_once = once; notify_handler = f }; +let make_notification ?(once=false) domain_id f = + let id = Atomic.fetch_and_add current_notification_id 1 in + Domain_map.update notifiers domain_id + (function + | None -> + let notifiers = Notifiers.create 1024 in + Notifiers.add notifiers id { notify_once = once; notify_handler = f }; + Some notifiers + | Some notifiers -> + Notifiers.add notifiers id { notify_once = once; notify_handler = f }; + Some notifiers); id -let stop_notification id = - Notifiers.remove notifiers id - -let set_notification id f = - let notifier = Notifiers.find notifiers id in - Notifiers.replace notifiers id { notifier with notify_handler = f } +let stop_notification domain_id id = + Domain_map.update notifiers domain_id + (function + | None -> None + | Some notifiers -> + Notifiers.remove notifiers id; + Some notifiers) -let call_notification id = - match Notifiers.find notifiers id with - | exception Not_found -> () - | notifier -> - if notifier.notify_once then - stop_notification id; - notifier.notify_handler () +let set_notification domain_id id f = + Domain_map.update notifiers domain_id + (function + | None -> raise Not_found + | Some notifiers -> + let notifier = Notifiers.find notifiers id in + Notifiers.replace notifiers id { notifier with notify_handler = f }; + Some notifiers) + +let call_notification domain_id id = + match Domain_map.find notifiers domain_id with + | None -> () + | Some notifiers -> + (match Notifiers.find notifiers id with + | exception Not_found -> () + | notifier -> + if notifier.notify_once then + Notifiers.remove notifiers id; + notifier.notify_handler ()) (* +-----------------------------------------------------------------+ | Sleepers | @@ -178,13 +192,8 @@ let cancel_jobs () = abort_jobs Lwt.Canceled let wait_for_jobs () = Lwt.join (Lwt_sequence.fold_l (fun (w, _) l -> w :: l) jobs []) -let wrap_result f x = - try - Result.Ok (f x) - with exn when Lwt.Exception_filter.run exn -> - Result.Error exn - let run_job_aux async_method job result = + let domain_id = Domain.self () in (* Starts the job. *) if start_job job async_method then (* The job has already terminated, read and return the result @@ -201,7 +210,7 @@ let run_job_aux async_method job result = ignore begin (* Create the notification for asynchronous wakeup. *) let id = - make_notification ~once:true + make_notification ~once:true domain_id (fun () -> Lwt_sequence.remove node; let result = result job in @@ -211,7 +220,7 @@ let run_job_aux async_method job result = notification. *) Lwt.pause () >>= fun () -> (* The job has terminated, send the result immediately. *) - if check_job job id then call_notification id; + if check_job job id then call_notification domain_id id; Lwt.return_unit end; waiter @@ -223,12 +232,7 @@ let choose_async_method = function | None -> match Lwt.get async_method_key with | Some am -> am - | None -> !default_async_method_var - -let execute_job ?async_method ~job ~result ~free = - let async_method = choose_async_method async_method in - run_job_aux async_method job (fun job -> let x = wrap_result result job in free job; x) -[@@ocaml.warning "-16"] + | None -> Atomic.get default_async_method_var external self_result : 'a job -> 'a = "lwt_unix_self_result" (* returns the result of a job using the [result] field of the C @@ -243,22 +247,7 @@ let self_result job = with exn when Lwt.Exception_filter.run exn -> Result.Error exn -let in_retention_test = ref false - -let retained o = - let retained = ref true in - Gc.finalise (fun _ -> - if !in_retention_test then - retained := false) - o; - in_retention_test := true; - retained - let run_job ?async_method job = - if !in_retention_test then begin - Gc.full_major (); - in_retention_test := false - end; let async_method = choose_async_method async_method in if async_method = Async_none then try @@ -2208,20 +2197,32 @@ let tcflow ch act = | Reading notifications | +-----------------------------------------------------------------+ *) -external init_notification : unit -> Unix.file_descr = "lwt_unix_init_notification" -external send_notification : int -> unit = "lwt_unix_send_notification_stub" -external recv_notifications : unit -> int array = "lwt_unix_recv_notifications" +external init_notification : Domain.id -> Unix.file_descr = "lwt_unix_init_notification_stub" +external send_notification : Domain.id -> int -> unit = "lwt_unix_send_notification_stub" +external recv_notifications : Domain.id -> int array = "lwt_unix_recv_notifications_stub" + +let handle_notifications domain_id (_ : Lwt_engine.event) = + Array.iter (call_notification domain_id) (recv_notifications domain_id) -let handle_notifications _ = - (* Process available notifications. *) - Array.iter call_notification (recv_notifications ()) +let event_notifications = Domain_map.create_protected_map () -let event_notifications = ref (Lwt_engine.on_readable (init_notification ()) handle_notifications) +let init_domain () = + let domain_id = Domain.self () in + let _ : notifier Notifiers.t = (Domain_map.init notifiers domain_id (fun () -> Notifiers.create 1024)) in + let _ : Lwt_engine.event = Domain_map.init event_notifications domain_id (fun () -> + let eventfd = init_notification domain_id in + Lwt_engine.on_readable eventfd (handle_notifications domain_id)) + in + () (* +-----------------------------------------------------------------+ | Signals | +-----------------------------------------------------------------+ *) +(* TODO: should all notifications for signals be on domain0? or should each + domain be able to install their own signal handler? what domain receives a + signal? *) + external set_signal : int -> int -> bool -> unit = "lwt_unix_set_signal" external remove_signal : int -> bool -> unit = "lwt_unix_remove_signal" external init_signals : unit -> unit = "lwt_unix_init_signals" @@ -2244,6 +2245,7 @@ type signal_handler = { and signal_handler_id = signal_handler option ref +(* TODO: what to do about signals? *) let signals = ref Signal_map.empty let signal_count () = Signal_map.fold @@ -2259,7 +2261,7 @@ let on_signal_full signum handler = with Not_found -> let actions = Lwt_sequence.create () in let notification = - make_notification + make_notification (Domain.self ()) (fun () -> Lwt_sequence.iter_l (fun f -> f id signum) @@ -2268,7 +2270,7 @@ let on_signal_full signum handler = (try set_signal signum notification with exn when Lwt.Exception_filter.run exn -> - stop_notification notification; + stop_notification (Domain.self ()) notification; raise exn); signals := Signal_map.add signum (notification, actions) !signals; (notification, actions) @@ -2290,7 +2292,7 @@ let disable_signal_handler id = if Lwt_sequence.is_empty actions then begin remove_signal sh.sh_num; signals := Signal_map.remove sh.sh_num !signals; - stop_notification notification + stop_notification (Domain.self ()) notification end let reinstall_signal_handler signum = @@ -2313,16 +2315,20 @@ let fork () = (* Reset threading. *) reset_after_fork (); (* Stop the old event for notifications. *) - Lwt_engine.stop_event !event_notifications; + let domain_id = Domain.self () in + (match Domain_map.find event_notifications domain_id with + | Some event -> Lwt_engine.stop_event event + | None -> ()); (* Reinitialise the notification system. *) - event_notifications := Lwt_engine.on_readable (init_notification ()) handle_notifications; + let new_event = Lwt_engine.on_readable (init_notification domain_id) (handle_notifications domain_id) in + Domain_map.add event_notifications domain_id new_event; (* Collect all pending jobs. *) let l = Lwt_sequence.fold_l (fun (_, f) l -> f :: l) jobs [] in (* Remove them all. *) Lwt_sequence.iter_node_l Lwt_sequence.remove jobs; (* And cancel them all. We yield first so that if the program do an exec just after, it won't be executed. *) - Lwt.on_termination (Lwt_main.yield () [@warning "-3"]) (fun () -> List.iter (fun f -> f Lwt.Canceled) l); + Lwt.on_termination (Lwt.pause ()) (fun () -> List.iter (fun f -> f Lwt.Canceled) l); 0 | pid -> pid @@ -2355,6 +2361,7 @@ let do_wait4 flags pid = let wait_children = Lwt_sequence.create () let wait_count () = Lwt_sequence.length wait_children +(* TODO: what to do about signals? especially sigchld signal? *) let sigchld_handler_installed = ref false let install_sigchld_handler () = @@ -2383,9 +2390,12 @@ let install_sigchld_handler () = install the SIGCHLD handler, in order to cause any EINTR-unsafe code to fail (as it should). *) let () = - Lwt.async (fun () -> - Lwt.pause () >|= fun () -> - install_sigchld_handler ()) + (* TODO: figure out what to do about signals *) + (* TODO: this interferes with tests because it leaves a pause hanging? *) + if (Domain.self () :> int) = 0 then + Lwt.async (fun () -> + Lwt.pause () >|= fun () -> + install_sigchld_handler ()) let _waitpid flags pid = Lwt.catch @@ -2462,8 +2472,6 @@ let system cmd = | Misc | +-----------------------------------------------------------------+ *) -let run = Lwt_main.run - let handle_unix_error f x = Lwt.catch (fun () -> f x) diff --git a/src/unix/lwt_unix.cppo.mli b/src/unix/lwt_unix.cppo.mli index c36d9a470..d400ff244 100644 --- a/src/unix/lwt_unix.cppo.mli +++ b/src/unix/lwt_unix.cppo.mli @@ -211,8 +211,7 @@ val fork : unit -> int - None of the above is necessary if you intend to call [exec]. Indeed, in that case, it is not even necessary to use [Lwt_unix.fork]. You can use {!Unix.fork}. - - To abandon some more promises, see - {!Lwt_main.abandon_yielded_and_paused}. *) + - To abandon some more promises, see {!Lwt.abandon_paused}. *) type process_status = Unix.process_status = @@ -1458,20 +1457,12 @@ val cancel_jobs : unit -> unit val wait_for_jobs : unit -> unit Lwt.t (** Wait for all pending jobs to terminate. *) -val execute_job : - ?async_method : async_method -> - job : 'a job -> - result : ('a job -> 'b) -> - free : ('a job -> unit) -> 'b Lwt.t - [@@ocaml.deprecated " Use Lwt_unix.run_job."] - (** @deprecated Use [run_job]. *) - (** {2 Notifications} *) (** Lwt internally use a pipe to send notification to the main thread. The following functions allow to use this pipe. *) -val make_notification : ?once : bool -> (unit -> unit) -> int +val make_notification : ?once : bool -> Domain.id -> (unit -> unit) -> int (** [make_notification ?once f] registers a new notifier. It returns the id of the notifier. Each time a notification with this id is received, [f] is called. @@ -1479,25 +1470,28 @@ val make_notification : ?once : bool -> (unit -> unit) -> int if [once] is specified, then the notification is stopped after the first time it is received. It defaults to [false]. *) -val send_notification : int -> unit +val send_notification : Domain.id -> int -> unit (** [send_notification id] sends a notification. This function is thread-safe. *) -val stop_notification : int -> unit +val stop_notification : Domain.id -> int -> unit (** Stop the given notification. Note that you should not reuse the id after the notification has been stopped, the result is unspecified if you do so. *) -val call_notification : int -> unit +val call_notification : Domain.id -> int -> unit (** Call the handler associated to the given notification. Note that if the notification was defined with [once = true] it is removed. *) -val set_notification : int -> (unit -> unit) -> unit +val set_notification : Domain.id -> int -> (unit -> unit) -> unit (** [set_notification id f] replace the function associated to the notification by [f]. It raises [Not_found] if the given notification is not found. *) +val init_domain : unit -> unit + (** call when Domain.spawn! and call on domain0 too, don't call twice for the same domain *) + (** {2 System threads pool} *) (** If the program is using the async method [Async_detach] or @@ -1579,10 +1573,6 @@ end (**/**) -val run : 'a Lwt.t -> 'a - [@@ocaml.deprecated " Use Lwt_main.run."] - (** @deprecated Use [Lwt_main.run]. *) - val has_wait4 : bool [@@ocaml.deprecated " Use Lwt_sys.have `wait4."] (** @deprecated Use [Lwt_sys.have `wait4]. *) @@ -1591,9 +1581,6 @@ val somaxconn : unit -> int [@@ocaml.deprecated " This is an internal function."] (** @deprecated This is for internal use only. *) -val retained : 'a -> bool ref - (** @deprecated Used for testing. *) - val read_bigarray : string -> file_descr -> IO_vectors._bigarray -> int -> int -> int Lwt.t [@@ocaml.deprecated " This is an internal function."] diff --git a/src/unix/lwt_unix.h b/src/unix/lwt_unix.h index ab4ad64bf..389082fda 100644 --- a/src/unix/lwt_unix.h +++ b/src/unix/lwt_unix.h @@ -95,7 +95,7 @@ void lwt_unix_not_available(char const *feature) Noreturn; +-----------------------------------------------------------------+ */ /* Sends a notification for the given id. */ -void lwt_unix_send_notification(intnat id); +void lwt_unix_send_notification(intnat domain_id, intnat id); /* +-----------------------------------------------------------------+ | Threading | @@ -196,6 +196,7 @@ struct lwt_unix_job { /* Id used to notify the main thread in case the job do not terminate immediately. */ + intnat domain_id; intnat notification_id; /* The function to call to do the work. diff --git a/src/unix/lwt_unix_stubs.c b/src/unix/lwt_unix_stubs.c index 443773bac..cb848f89f 100644 --- a/src/unix/lwt_unix_stubs.c +++ b/src/unix/lwt_unix_stubs.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -495,15 +496,6 @@ CAMLprim value lwt_unix_socketpair_stub(value cloexec, value domain, value type, /* The mutex used to send and receive notifications. */ static lwt_unix_mutex notification_mutex; -/* All pending notifications. */ -static intnat *notifications = NULL; - -/* The size of the notification buffer. */ -static long notification_count = 0; - -/* The index to the next available cell in the notification buffer. */ -static long notification_index = 0; - /* The mode currently used for notifications. */ enum notification_mode { /* Not yet initialized. */ @@ -522,35 +514,59 @@ enum notification_mode { NOTIFICATION_MODE_WINDOWS }; -/* The current notification mode. */ -static enum notification_mode notification_mode = - NOTIFICATION_MODE_NOT_INITIALIZED; +/* Domain-specific notification state */ +struct domain_notification_state { + intnat *notifications; + long notification_count; + long notification_index; + enum notification_mode notification_mode; +#if defined(HAVE_EVENTFD) + int notification_fd; +#endif + int notification_fds[2]; +}; + +/* table to store per-domain notification state */ +#define MAX_DOMAINS 64 // TODO: review values +static struct domain_notification_state domain_states[MAX_DOMAINS]; +static int domain_states_initialized[MAX_DOMAINS] = {0}; /* Send one notification. */ -static int (*notification_send)(); +static int (*notification_send)(int domain_id); /* Read one notification. */ -static int (*notification_recv)(); +static int (*notification_recv)(int domain_id); static void init_notifications() { lwt_unix_mutex_init(¬ification_mutex); - notification_count = 4096; - notifications = - (intnat *)lwt_unix_malloc(notification_count * sizeof(intnat)); } -static void resize_notifications() { - long new_notification_count = notification_count * 2; - intnat *new_notifications = - (intnat *)lwt_unix_malloc(new_notification_count * sizeof(intnat)); - memcpy((void *)new_notifications, (void *)notifications, - notification_count * sizeof(intnat)); - free(notifications); - notifications = new_notifications; - notification_count = new_notification_count; +static void init_domain_notifications(int domain_id) { + if (domain_id >= 0 && domain_id < MAX_DOMAINS && !domain_states_initialized[domain_id]) { + domain_states[domain_id].notification_count = 4096; + domain_states[domain_id].notifications = + (intnat *)lwt_unix_malloc(domain_states[domain_id].notification_count * sizeof(intnat)); + domain_states[domain_id].notification_index = 0; + domain_states[domain_id].notification_mode = NOTIFICATION_MODE_NOT_INITIALIZED; + domain_states_initialized[domain_id] = 1; + } +} + +static void resize_notifications(int domain_id) { + if (domain_id >= 0 && domain_id < MAX_DOMAINS && domain_states_initialized[domain_id]) { + struct domain_notification_state *state = &domain_states[domain_id]; + long new_notification_count = state->notification_count * 2; + intnat *new_notifications = + (intnat *)lwt_unix_malloc(new_notification_count * sizeof(intnat)); + memcpy((void *)new_notifications, (void *)state->notifications, + state->notification_count * sizeof(intnat)); + free(state->notifications); + state->notifications = new_notifications; + state->notification_count = new_notification_count; + } } -void lwt_unix_send_notification(intnat id) { +void lwt_unix_send_notification(intnat domain_id, intnat id) { int ret; #if !defined(LWT_ON_WINDOWS) sigset_t new_mask; @@ -561,33 +577,37 @@ void lwt_unix_send_notification(intnat id) { #else DWORD error; #endif + init_domain_notifications(domain_id); lwt_unix_mutex_lock(¬ification_mutex); - if (notification_index > 0) { - /* There is already a pending notification in the buffer, no - need to signal the main thread. */ - if (notification_index == notification_count) resize_notifications(); - notifications[notification_index++] = id; - } else { - /* There is none, notify the main thread. */ - notifications[notification_index++] = id; - ret = notification_send(); + if (domain_id >= 0 && domain_id < MAX_DOMAINS && domain_states_initialized[domain_id]) { + struct domain_notification_state *state = &domain_states[domain_id]; + if (state->notification_index > 0) { + /* There is already a pending notification in the buffer, no + need to signal the main thread. */ + if (state->notification_index == state->notification_count) resize_notifications(domain_id); + state->notifications[state->notification_index++] = id; + } else { + /* There is none, notify the main thread. */ + state->notifications[state->notification_index++] = id; + ret = notification_send(domain_id); #if defined(LWT_ON_WINDOWS) - if (ret == SOCKET_ERROR) { - error = WSAGetLastError(); - if (error != WSANOTINITIALISED) { - lwt_unix_mutex_unlock(¬ification_mutex); - win32_maperr(error); - uerror("send_notification", Nothing); - } /* else we're probably shutting down, so ignore the error */ - } + if (ret == SOCKET_ERROR) { + error = WSAGetLastError(); + if (error != WSANOTINITIALISED) { + lwt_unix_mutex_unlock(¬ification_mutex); + win32_maperr(error); + uerror("send_notification", Nothing); + } /* else we're probably shutting down, so ignore the error */ + } #else - if (ret < 0) { - error = errno; - lwt_unix_mutex_unlock(¬ification_mutex); - pthread_sigmask(SIG_SETMASK, &old_mask, NULL); - unix_error(error, "send_notification", Nothing); - } + if (ret < 0) { + error = errno; + lwt_unix_mutex_unlock(¬ification_mutex); + pthread_sigmask(SIG_SETMASK, &old_mask, NULL); + unix_error(error, "send_notification", Nothing); + } #endif + } } lwt_unix_mutex_unlock(¬ification_mutex); #if !defined(LWT_ON_WINDOWS) @@ -595,12 +615,12 @@ void lwt_unix_send_notification(intnat id) { #endif } -value lwt_unix_send_notification_stub(value id) { - lwt_unix_send_notification(Long_val(id)); +value lwt_unix_send_notification_stub(value domain_id, value id) { + lwt_unix_send_notification(Long_val(domain_id), Long_val(id)); return Val_unit; } -value lwt_unix_recv_notifications() { +value lwt_unix_recv_notifications(intnat domain_id) { int ret, i, current_index; value result; #if !defined(LWT_ON_WINDOWS) @@ -612,9 +632,11 @@ value lwt_unix_recv_notifications() { #else DWORD error; #endif + /* Initialize domain state if needed */ + init_domain_notifications(domain_id); lwt_unix_mutex_lock(¬ification_mutex); /* Receive the signal. */ - ret = notification_recv(); + ret = notification_recv(domain_id); #if defined(LWT_ON_WINDOWS) if (ret == SOCKET_ERROR) { error = WSAGetLastError(); @@ -631,25 +653,35 @@ value lwt_unix_recv_notifications() { } #endif - do { - /* - release the mutex while calling caml_alloc, - which may call gc and switch the thread, - resulting in a classical deadlock, - when thread in question tries another send - */ - current_index = notification_index; + if (domain_id >= 0 && domain_id < MAX_DOMAINS && domain_states_initialized[domain_id]) { + struct domain_notification_state *state = &domain_states[domain_id]; + + do { + /* + release the mutex while calling caml_alloc, + which may call gc and switch the thread, + resulting in a classical deadlock, + when thread in question tries another send + */ + current_index = state->notification_index; + lwt_unix_mutex_unlock(¬ification_mutex); + result = caml_alloc_tuple(current_index); + lwt_unix_mutex_lock(¬ification_mutex); + /* check that no new notifications appeared meanwhile (rare) */ + } while (current_index != state->notification_index); + + /* Read all pending notifications. */ + for (i = 0; i < state->notification_index; i++) { + Field(result, i) = Val_long(state->notifications[i]); + } + /* Reset the index. */ + state->notification_index = 0; + } else { + /* Domain not initialized, return empty array */ lwt_unix_mutex_unlock(¬ification_mutex); - result = caml_alloc_tuple(current_index); + result = caml_alloc_tuple(0); lwt_unix_mutex_lock(¬ification_mutex); - /* check that no new notifications appeared meanwhile (rare) */ - } while (current_index != notification_index); - - /* Read all pending notifications. */ - for (i = 0; i < notification_index; i++) - Field(result, i) = Val_long(notifications[i]); - /* Reset the index. */ - notification_index = 0; + } lwt_unix_mutex_unlock(¬ification_mutex); #if !defined(LWT_ON_WINDOWS) pthread_sigmask(SIG_SETMASK, &old_mask, NULL); @@ -657,21 +689,26 @@ value lwt_unix_recv_notifications() { return result; } +value lwt_unix_recv_notifications_stub(value domain_id) { + value res = lwt_unix_recv_notifications(Long_val(domain_id)); + return res; +} + #if defined(LWT_ON_WINDOWS) static SOCKET socket_r, socket_w; -static int windows_notification_send() { +static int windows_notification_send(int domain_id) { char buf = '!'; return send(socket_w, &buf, 1, 0); } -static int windows_notification_recv() { +static int windows_notification_recv(int domain_id) { char buf; return recv(socket_r, &buf, 1, 0); } -value lwt_unix_init_notification() { +value lwt_unix_init_notification(intnat domain_id) { SOCKET sockets[2]; switch (notification_mode) { @@ -702,6 +739,7 @@ value lwt_unix_init_notification() { return win_alloc_socket(socket_r); } + #else /* defined(LWT_ON_WINDOWS) */ static void set_close_on_exec(int fd) { @@ -712,47 +750,69 @@ static void set_close_on_exec(int fd) { #if defined(HAVE_EVENTFD) -static int notification_fd; - -static int eventfd_notification_send() { +static int eventfd_notification_send(int domain_id) { uint64_t buf = 1; - return write(notification_fd, (char *)&buf, 8); + if (domain_id < 0 || domain_id >= MAX_DOMAINS || !domain_states_initialized[domain_id]) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = write(state->notification_fd, (char *)&buf, 8); + return result; } -static int eventfd_notification_recv() { +static int eventfd_notification_recv(int domain_id) { uint64_t buf; - return read(notification_fd, (char *)&buf, 8); + if (domain_id < 0 || domain_id >= MAX_DOMAINS || !domain_states_initialized[domain_id]) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = read(state->notification_fd, (char *)&buf, 8); + return result; } #endif /* defined(HAVE_EVENTFD) */ -static int notification_fds[2]; - -static int pipe_notification_send() { +static int pipe_notification_send(int domain_id) { char buf = 0; - return write(notification_fds[1], &buf, 1); + if (domain_id < 0 || domain_id >= MAX_DOMAINS || !domain_states_initialized[domain_id]) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = write(state->notification_fds[1], &buf, 1); + return result; } -static int pipe_notification_recv() { +static int pipe_notification_recv(int domain_id) { char buf; - return read(notification_fds[0], &buf, 1); + if (domain_id < 0 || domain_id >= MAX_DOMAINS || !domain_states_initialized[domain_id]) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = read(state->notification_fds[0], &buf, 1); + return result; } -value lwt_unix_init_notification() { - switch (notification_mode) { +value lwt_unix_init_notification(int domain_id) { + /* Initialize domain state if needed */ + init_domain_notifications(domain_id); + if (domain_id < 0 || domain_id >= MAX_DOMAINS || !domain_states_initialized[domain_id]) { + caml_failwith("invalid domain_id in lwt_unix_init_notification"); + } + struct domain_notification_state *state = &domain_states[domain_id]; + switch (state->notification_mode) { #if defined(HAVE_EVENTFD) case NOTIFICATION_MODE_EVENTFD: - notification_mode = NOTIFICATION_MODE_NONE; - if (close(notification_fd) == -1) uerror("close", Nothing); + state->notification_mode = NOTIFICATION_MODE_NONE; + if (close(state->notification_fd) == -1) uerror("close", Nothing); break; #endif case NOTIFICATION_MODE_PIPE: - notification_mode = NOTIFICATION_MODE_NONE; - if (close(notification_fds[0]) == -1) uerror("close", Nothing); - if (close(notification_fds[1]) == -1) uerror("close", Nothing); + state->notification_mode = NOTIFICATION_MODE_NONE; + if (close(state->notification_fds[0]) == -1) uerror("close", Nothing); + if (close(state->notification_fds[1]) == -1) uerror("close", Nothing); break; case NOTIFICATION_MODE_NOT_INITIALIZED: - notification_mode = NOTIFICATION_MODE_NONE; + state->notification_mode = NOTIFICATION_MODE_NONE; init_notifications(); break; case NOTIFICATION_MODE_NONE: @@ -762,27 +822,32 @@ value lwt_unix_init_notification() { } #if defined(HAVE_EVENTFD) - notification_fd = eventfd(0, 0); - if (notification_fd != -1) { - notification_mode = NOTIFICATION_MODE_EVENTFD; + state->notification_fd = eventfd(0, 0); + if (state->notification_fd != -1) { + state->notification_mode = NOTIFICATION_MODE_EVENTFD; notification_send = eventfd_notification_send; notification_recv = eventfd_notification_recv; - set_close_on_exec(notification_fd); - return Val_int(notification_fd); + set_close_on_exec(state->notification_fd); + return Val_int(state->notification_fd); } #endif - if (pipe(notification_fds) == -1) uerror("pipe", Nothing); - set_close_on_exec(notification_fds[0]); - set_close_on_exec(notification_fds[1]); - notification_mode = NOTIFICATION_MODE_PIPE; + if (pipe(state->notification_fds) == -1) uerror("pipe", Nothing); + set_close_on_exec(state->notification_fds[0]); + set_close_on_exec(state->notification_fds[1]); + state->notification_mode = NOTIFICATION_MODE_PIPE; notification_send = pipe_notification_send; notification_recv = pipe_notification_recv; - return Val_int(notification_fds[0]); + return Val_int(state->notification_fds[0]); } #endif /* defined(LWT_ON_WINDOWS) */ +CAMLprim value lwt_unix_init_notification_stub(value domain_id) { + value res = lwt_unix_init_notification(Long_val(domain_id)); + return res; +} + /* +-----------------------------------------------------------------+ | Signals | +-----------------------------------------------------------------+ */ @@ -797,7 +862,7 @@ static intnat signal_notifications[NSIG]; CAMLextern int caml_convert_signal_number(int); /* Send a notification when a signal is received. */ -static void handle_signal(int signum) { +void handle_signal(int signum) { if (signum >= 0 && signum < NSIG) { intnat id = signal_notifications[signum]; if (id != -1) { @@ -806,7 +871,9 @@ static void handle_signal(int signum) { function. */ signal(signum, handle_signal); #endif - lwt_unix_send_notification(id); + //TODO: domain_self instead of root (0)? caml doesn't expose + //caml_ml_domain_id in domain.h :( + lwt_unix_send_notification(0, id); } } } @@ -822,7 +889,9 @@ static BOOL WINAPI handle_break(DWORD event) { intnat id = signal_notifications[SIGINT]; if (id == -1 || (event != CTRL_C_EVENT && event != CTRL_BREAK_EVENT)) return FALSE; - lwt_unix_send_notification(id); + //TODO: domain_self instead of root (0)? caml doesn't expose + //caml_ml_domain_id in domain.h :( + lwt_unix_send_notification(0, id); return TRUE; } #endif @@ -909,7 +978,7 @@ CAMLprim value lwt_unix_init_signals(value Unit) { +-----------------------------------------------------------------+ */ /* Execute the given job. */ -static void execute_job(lwt_unix_job job) { +void execute_job(lwt_unix_job job) { DEBUG("executing the job"); lwt_unix_mutex_lock(&job->mutex); @@ -937,7 +1006,7 @@ static void execute_job(lwt_unix_job job) { if (job->fast == 0) { lwt_unix_mutex_unlock(&job->mutex); DEBUG("notifying the main thread"); - lwt_unix_send_notification(job->notification_id); + lwt_unix_send_notification(job->domain_id, job->notification_id); } else { lwt_unix_mutex_unlock(&job->mutex); DEBUG("not notifying the main thread"); @@ -990,7 +1059,7 @@ void initialize_threading() { /* Function executed by threads of the pool. * Note: all signals are masked for this thread. */ -static void *worker_loop(void *data) { +void *worker_loop(void *data) { lwt_unix_job job = (lwt_unix_job)data; /* Execute the initial job if any. */ diff --git a/test/core/test_lwt.ml b/test/core/test_lwt.ml index f22c72233..d33f97725 100644 --- a/test/core/test_lwt.ml +++ b/test/core/test_lwt.ml @@ -47,7 +47,6 @@ let add_loc exn = try raise exn with exn -> exn let suites : Test.suite list = [] - (* Tests for promises created with [Lwt.return], [Lwt.fail], and related functions, as well as state query (hard to test one without the other). These tests use assertions instead of relying on the correctness of a final @@ -2124,6 +2123,7 @@ let both_tests = suite "both" [ state_is Lwt.Sleep p end; + test "pending, fulfilled, then fulfilled" begin fun () -> let p1, r1 = Lwt.wait () in let p = Lwt.both p1 (Lwt.return 2) in @@ -4205,7 +4205,7 @@ let lwt_sequence_tests = suite "add_task_l and add_task_r" [ let suites = suites @ [lwt_sequence_tests] - +(* let pause_tests = suite "pause" [ test "initial state" begin fun () -> Lwt.return (Lwt.paused_count () = 0) @@ -4290,6 +4290,7 @@ let pause_tests = suite "pause" [ end; ] let suites = suites @ [pause_tests] +*) diff --git a/test/multidomain/basic.ml b/test/multidomain/basic.ml new file mode 100644 index 000000000..a8b2edf9b --- /dev/null +++ b/test/multidomain/basic.ml @@ -0,0 +1,55 @@ +open Lwt.Syntax + +(* we don't call run in the root domain so we initialise by hand *) +let () = Lwt_unix.init_domain () + +let p_one, w_one = Lwt.wait () +let v_one = 3 +let p_two, w_two = Lwt.wait () +let v_two = 2 + +let d_mult = Domain.spawn (fun () -> + Lwt_unix.init_domain (); + (* domain one: wait for value from domain two then work and then send a value *) + Lwt_main.run ( + let* () = Lwt_unix.sleep 0.01 in + let* v_two = p_two in +(* Printf.printf "d%d received %d\n" (Domain.self () :> int) v_two; *) + let* () = Lwt_unix.sleep 0.1 in + Lwt.wakeup w_one v_one; +(* Printf.printf "d%d sent %d\n" (Domain.self () :> int) v_one; *) + Lwt.return (v_two * v_one) + ) +) +let d_sum = Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let () = + (* concurrent thread within domain "two" send a value and then work and + then wait for a value from domain one *) + Lwt.dont_wait (fun () -> + let* () = Lwt_unix.sleep 0.1 in +(* Printf.printf "d%d slept\n" (Domain.self () :> int); *) + Lwt.wakeup w_two v_two; +(* Printf.printf "d%d sent %d\n" (Domain.self () :> int) v_two; *) + Lwt.return () + ) + (fun _ -> exit 1) + in + let* v_one = p_one in + Lwt.return (v_two + v_one) + ) +) + + +let mult = Domain.join d_mult +let sum = Domain.join d_sum + +let () = + if mult = v_one * v_two && sum = v_one + v_two then begin + Printf.printf "basic: ✓\n"; + exit 0 + end else begin + Printf.printf "basic: ×\n"; + exit 1 + end diff --git a/test/multidomain/domainworkers.ml b/test/multidomain/domainworkers.ml new file mode 100644 index 000000000..1d04da1f6 --- /dev/null +++ b/test/multidomain/domainworkers.ml @@ -0,0 +1,74 @@ +open Lwt.Syntax + +let rec worker recv_task f send_result = + let* task = Lwt_stream.get recv_task in + match task with + | None -> +(* let () = Printf.printf "worker(%d) received interrupt\n" (Domain.self () :> int); flush_all() in *) + send_result None; + Lwt.return () + | Some data -> +(* let () = Printf.printf "worker(%d) received task (%S)\n" (Domain.self () :> int) data; flush_all() in *) + let* result = f data in + send_result (Some result); +(* let () = Printf.printf "worker(%d) sent result (%d)\n" (Domain.self () :> int) result; flush_all() in *) + let* () = Lwt.pause () in + worker recv_task f send_result + +let spawn_domain_worker f = + let recv_task, send_task = Lwt_stream.create () in + let recv_result, send_result = Lwt_stream.create () in + let dw = + Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + worker recv_task f send_result + ) + ) + in + send_task, dw, recv_result + +let simulate_work data = + let simulated_work_duration = String.length data in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return (String.length data) + +let input = [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] +let expected_result = List.fold_left (fun acc s -> acc + String.length s) 0 input + +let main () = + let send_task1, dw1, recv_result1 = spawn_domain_worker simulate_work in + let send_task2, dw2, recv_result2 = spawn_domain_worker simulate_work in + let l = + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + let () = (* push work *) + List.iteri + (fun idx s -> if idx mod 3 = 0 then send_task1 (Some s) else send_task2 (Some s)) + input + in + send_task1 None; + send_task2 None; + let* lengths1 = Lwt_stream.fold (+) recv_result1 0 + and* lengths2 = Lwt_stream.fold (+) recv_result2 0 + in + Lwt.return (lengths1 + lengths2) + ) + in + let () = Domain.join dw1 in + let () = Domain.join dw2 in + let code = + if l = expected_result then begin + Printf.printf "domain-workers: ✓\n"; + 0 + end else begin + Printf.printf "domain-workers: ×\n"; + 1 + end + in + flush_all (); + exit code + +let () = main () diff --git a/test/multidomain/dune b/test/multidomain/dune new file mode 100644 index 000000000..2cddc5bbf --- /dev/null +++ b/test/multidomain/dune @@ -0,0 +1,3 @@ +(tests + (names basic domainworkers movingpromises) + (libraries lwt lwt.unix)) diff --git a/test/multidomain/movingpromises.ml b/test/multidomain/movingpromises.ml new file mode 100644 index 000000000..34d47f4d2 --- /dev/null +++ b/test/multidomain/movingpromises.ml @@ -0,0 +1,83 @@ +open Lwt.Syntax + +let rec worker ongoing_tasks recv_task f = + let* task = Lwt_stream.get recv_task in + match task with + | None -> +(* let () = Printf.printf "worker(%d) received interrupt\n" (Domain.self () :> int); flush_all() in *) + Lwt.join ongoing_tasks + | Some (_idx, data, resolver) -> + let task = +(* let () = Printf.printf "worker(%d) received task(%d)\n" (Domain.self () :> int) _idx; flush_all() in *) + let* data in +(* let () = Printf.printf "worker(%d) received task(%d) data(%S)\n" (Domain.self () :> int) _idx data; flush_all() in *) + let* result = f data in + Lwt.wakeup resolver result; +(* let () = Printf.printf "worker(%d) sent result(%d) for task(%d)\n" (Domain.self () :> int) result _idx; flush_all() in *) + Lwt.return () + in + let* () = Lwt.pause () in + worker (task :: ongoing_tasks) recv_task f + +let spawn_domain_worker f = + let recv_task, send_task = Lwt_stream.create () in + let dw = + Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + worker [] recv_task f + ) + ) + in + send_task, dw + +let simulate_work data = + let simulated_work_duration = String.length data in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return (String.length data) + +let simulate_input data = + let simulated_work_duration = max 1 (10 - String.length data) in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return data + +let input = [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] +let expected_result = input |> List.map String.length |> List.map string_of_int |> String.concat "," + +let main () = + let send_task1, dw1 = spawn_domain_worker simulate_work in + let send_task2, dw2 = spawn_domain_worker simulate_work in + let l = + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + let inputs = List.map simulate_input + [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] + in + let* lengths = + Lwt_list.mapi_p + (fun idx s -> + let (p, r) = Lwt.task () in + begin if idx mod 3 = 0 then send_task1 (Some (idx, s, r)) else send_task2 (Some (idx, s, r)) end; + p) + inputs + in + let* () = Lwt.pause () in + send_task1 None; + send_task2 None; + let lengths = lengths |> List.map string_of_int |> String.concat "," in + Lwt.return lengths + ) + in + let () = Domain.join dw1 in + let () = Domain.join dw2 in + if l = expected_result then begin + Printf.printf "moving-promises: ✓\n"; + exit 0 + end else begin + Printf.printf "moving-promises: ×\n"; + exit 1 + end + +let () = main () diff --git a/test/test.ml b/test/test.ml index bc18a36bb..037c1e325 100644 --- a/test/test.ml +++ b/test/test.ml @@ -2,7 +2,6 @@ details, or visit https://github.com/ocsigen/lwt/blob/master/LICENSE.md. *) - type test = { test_name : string; skip_if_this_is_false : unit -> bool; diff --git a/test/unix/main.ml b/test/unix/main.ml index 34d2d4983..7e36f99c3 100644 --- a/test/unix/main.ml +++ b/test/unix/main.ml @@ -1,12 +1,16 @@ (* This file is part of Lwt, released under the MIT license. See LICENSE.md for details, or visit https://github.com/ocsigen/lwt/blob/master/LICENSE.md. *) +let () = Lwt_unix.init_domain () + open Tester let () = Test.concurrent "unix" [ +(* Test_lwt_unix.suite; Test_lwt_io.suite; +*) Test_lwt_io_non_block.suite; Test_lwt_process.suite; Test_lwt_engine.suite; diff --git a/test/unix/test_lwt_bytes.ml b/test/unix/test_lwt_bytes.ml index 6de438b8d..fa6d328da 100644 --- a/test/unix/test_lwt_bytes.ml +++ b/test/unix/test_lwt_bytes.ml @@ -597,23 +597,6 @@ let suite = suite "lwt_bytes" [ Lwt.return check end; - test "read: buffer retention" ~sequential:true begin fun () -> - let buffer = Lwt_bytes.create 3 in - - let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in - Lwt_unix.set_blocking read_fd true; - - Lwt_unix.write_string write_fd "foo" 0 3 >>= fun _ -> - - let retained = Lwt_unix.retained buffer in - Lwt_bytes.read read_fd buffer 0 3 >>= fun _ -> - - Lwt_unix.close write_fd >>= fun () -> - Lwt_unix.close read_fd >|= fun () -> - - !retained - end; - test "bytes write" begin fun () -> let test_file = "bytes_io_data_write" in Lwt_unix.openfile test_file [O_RDWR;O_TRUNC; O_CREAT] 0o666 @@ -634,21 +617,6 @@ let suite = suite "lwt_bytes" [ Lwt.return check end; - test "write: buffer retention" ~sequential:true begin fun () -> - let buffer = Lwt_bytes.create 3 in - - let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in - Lwt_unix.set_blocking write_fd true; - - let retained = Lwt_unix.retained buffer in - Lwt_bytes.write write_fd buffer 0 3 >>= fun _ -> - - Lwt_unix.close write_fd >>= fun () -> - Lwt_unix.close read_fd >|= fun () -> - - !retained - end; - test "bytes recv" ~only_if:(fun () -> not Sys.win32) begin fun () -> let buf = gen_buf 6 in let server_logic socket = diff --git a/test/unix/test_lwt_unix.ml b/test/unix/test_lwt_unix.ml index a4e747aa3..22edc1ee0 100644 --- a/test/unix/test_lwt_unix.ml +++ b/test/unix/test_lwt_unix.ml @@ -6,6 +6,8 @@ open Test open Lwt.Infix +let domain_root_id = Domain.self () + (* An instance of the tester for the wait/waitpid tests. *) let () = match Sys.argv with @@ -451,12 +453,14 @@ let readv_tests = Lwt_unix.write_string write_fd "foo" 0 3 >>= fun _ -> - let retained = Lwt_unix.retained io_vectors in + let retained = ref true in + Gc.finalise (fun _ -> retained := false) io_vectors; Lwt_unix.readv read_fd io_vectors >>= fun _ -> Lwt_unix.close write_fd >>= fun () -> Lwt_unix.close read_fd >|= fun () -> + Gc.full_major (); !retained end; @@ -619,12 +623,14 @@ let writev_tests = let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in Lwt_unix.set_blocking write_fd true; - let retained = Lwt_unix.retained io_vectors in + let retained = ref true in + Gc.finalise (fun _ -> retained := false) io_vectors; Lwt_unix.writev write_fd io_vectors >>= fun _ -> Lwt_unix.close write_fd >>= fun () -> Lwt_unix.close read_fd >|= fun () -> + Gc.full_major (); !retained end; @@ -1054,19 +1060,19 @@ let dir_tests = [ ] let lwt_preemptive_tests = [ - test "run_in_main" begin fun () -> + test "run_in_domain" begin fun () -> let f () = - Lwt_preemptive.run_in_main (fun () -> + Lwt_preemptive.run_in_domain domain_root_id (fun () -> Lwt_unix.sleep 0.01 >>= fun () -> Lwt.return 42) in Lwt_preemptive.detach f () >>= fun x -> Lwt.return (x = 42) end; - test "run_in_main_dont_wait" begin fun () -> + test "run_in_domain_dont_wait" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main_dont_wait + Lwt_preemptive.run_in_domain_dont_wait domain_root_id (fun () -> Lwt.pause () >>= fun () -> Lwt.pause () >>= fun () -> @@ -1078,10 +1084,10 @@ let lwt_preemptive_tests = [ p >>= fun x -> Lwt.return (x = 42) end; - test "run_in_main_dont_wait_fail" begin fun () -> + test "run_in_domain_dont_wait_fail" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main_dont_wait + Lwt_preemptive.run_in_domain_dont_wait domain_root_id (fun () -> Lwt.pause () >>= fun () -> Lwt.pause () >>= fun () -> @@ -1092,10 +1098,10 @@ let lwt_preemptive_tests = [ p >>= fun x -> Lwt.return (x = 45) end; - test "run_in_main_with_dont_wait" begin fun () -> + test "run_in_domain_with_dont_wait" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main (fun () -> + Lwt_preemptive.run_in_domain domain_root_id (fun () -> Lwt.dont_wait (fun () -> Lwt.pause () >>= fun () ->