
Dev 4.3.5 #187

Closed
tngchien wants to merge 22 commits into MooreThreads:main from tngchien:dev_4.3.5

Conversation

@tngchien
Collaborator

No description provided.

timo added 22 commits April 23, 2026 16:40
Replace the in-tree device/stream/event/allocator implementations with the
stable `tensorflow/c/experimental/stream_executor/stream_executor.h` plugin
API so the MUSA extension builds against the official PluggableDevice
contract rather than reaching into TF's private common_runtime headers.

Main changes:

- Drop the bespoke MUSA device stack (musa_allocator, musa_device,
  musa_event, musa_event_mgr, musa_executor, musa_host_allocator,
  musa_platform, musa_stream, musa_utils, pinned_memory_pool).
- Add musa_se_callbacks.{cc,h} implementing SP_Platform / SP_StreamExecutor
  / SP_TimerFns callbacks (create/destroy stream + event, memcpy H2D/D2H/
  D2D, sync variants, allocator, host-memory pinning cache, H2D pinned
  staging pool for pageable sources, and a small event pool to avoid
  per-step driver churn).
- Add musa_resource_mgr.{cc,h} to cache mudnn/mublas handles per
  (device, stream) and share them across kernels; thread-local fast path
  removes the old mutex on every Compute().
- Rewrite device_register.cc to register the plugin through SE_InitPlugin
  with idempotency + constructor/destructor hooks for telemetry.
- Update utils_op.{cc,h} with new helpers (GetMusaStreamByCtx,
  GetHandleByCtx, GetMublasHandleByCtx, MakeMusaMemMaintainer,
  CachedMusaSetDevice). Propagate the new helpers through all kernels
  that used to reach into MusaDevice directly.
- Adjust CMakeLists.txt: drop include paths into TF's common_runtime,
  add -DTF_PLUGGABLE_DEVICE.
- test/musa_test_utils.py now loads the plugin via
  `load_pluggable_device_library` (with an idempotency guard to survive
  repeated test-runner invocations).
- README/README.en: document the new plugin loading flow.
MusaMemcpy{D2H,H2D,D2D} and their Async siblings used to call
`musaGetDevice` unconditionally so the telemetry event could carry the
device ordinal. In release builds the telemetry macro is a no-op, which
makes that driver round-trip pure overhead. Hot kernels (AddN, Concat,
ResourceVariable assign, ...) walk these paths many times per step, so
the cost is observable on CPU-bound inference.

Wrap the prologue in a `MUSA_MEMCPY_TELEMETRY_PROLOGUE` macro that is
compiled out when `MUSA_DISABLE_TRACE_LOGGING` is defined, keeping full
telemetry for tracing builds and removing the overhead everywhere else.
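The compile-out prologue can be sketched as below. This is an illustrative reconstruction, not the actual plugin code: `FakeGetDevice` and the counter are stand-ins for the `musaGetDevice` driver round-trip, and only the macro's name and gating define come from the commit message.

```cpp
#include <string>

// Illustrative stand-in for the musaGetDevice driver round-trip.
static int g_device_queries = 0;
static int FakeGetDevice() { ++g_device_queries; return 0; }

// When MUSA_DISABLE_TRACE_LOGGING is defined, the prologue expands to
// nothing, so release builds skip the driver query entirely.
#ifndef MUSA_DISABLE_TRACE_LOGGING
#define MUSA_MEMCPY_TELEMETRY_PROLOGUE(dev) \
  int dev = FakeGetDevice();                \
  (void)dev;
#else
#define MUSA_MEMCPY_TELEMETRY_PROLOGUE(dev)
#endif

std::string MusaMemcpyH2D() {
  MUSA_MEMCPY_TELEMETRY_PROLOGUE(device_ordinal)
  // ... issue the actual copy; a tracing build would log device_ordinal ...
  return std::string("copied");
}
```

Because the gate is a preprocessor define rather than a runtime flag, the release-build hot path contains no branch at all.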
`UseAddBroadcastViewOpt` and `UseAddCustomKernelFastPath` called
`std::getenv` on every invocation. Inference graphs can contain hundreds
of AddV2 nodes, so each Compute() was paying an O(n) walk of the env
table. Cache the result in function-local static booleans evaluated on
the first call.

No behavior change: same env vars, same default (enabled).
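The caching pattern described above can be sketched with a function-local static; the env-var name used here is a placeholder, not the real knob:

```cpp
#include <cstdlib>
#include <string>

// std::getenv runs once per process, on the first call, instead of on
// every Compute(). Default (enabled) matches the behavior described
// above; TF_MUSA_ADD_BROADCAST_VIEW_OPT is a hypothetical name.
bool UseAddBroadcastViewOpt() {
  static const bool enabled = [] {
    const char* v = std::getenv("TF_MUSA_ADD_BROADCAST_VIEW_OPT");
    return v == nullptr || std::string(v) != "0";
  }();
  return enabled;
}
```

C++11 guarantees the static initializer runs exactly once even under concurrent first calls, so no extra locking is needed.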
The matcher emits two `LOG(INFO)` messages for every successful pattern
match and apply. On a mid-size inference graph this fires hundreds of
times per Optimize() call and clutters the tf.Session startup log
without adding diagnostic value (the pattern manager already reports
aggregate counts). Comment them out so they can be re-enabled with a
one-line revert if a matcher regression needs to be debugged.
When callers import graphs under `with tf.device("/device:MUSA:0")`,
every node inherits an explicit MUSA device placement, including the
small int32 Shape/StridedSlice/Pack chains that feed Reshape/Tile/
BroadcastTo/... through HostMemory-typed inputs. TF's own
`pin_to_host_optimization` refuses to override explicit placements, so
the chain stays on the device and each hop triggers an H2D or D2H plus a
`musaStreamSynchronize` because the consumer reads host memory. On
this workload that pattern accounts for ~3 ms of profiled time per
step (batch=100) concentrated in 3 Pack + 4 StridedSlice nodes.

Add a PinHostComputeToCpu pass that runs after fusion and rewrites
`node.device()` to `/device:CPU:0` for shape-arithmetic subgraphs.
Correctness constraints:

  * Only structurally safe ops are candidates: Shape/ShapeN/Size/Rank
    unconditionally, plus an integer-typed set (Const, Cast,
    StridedSlice, Pack, ConcatV2, Range, Fill, Reshape, arithmetic,
    reductions, comparisons, booleans). FP variants are never moved.
  * Forward check: every non-peer consumer must treat the output as a
    HostMemory input (table is kept in sync with the MUSA kernel
    .HostMemory(...) registrations).
  * Backward check: every non-peer input source must be a peer, a
    Const, or an op whose MUSA kernel already outputs to host memory.
  * The candidate set is iterated to a fixpoint (capped at 32 sweeps)
    before any device() is rewritten, so the decision is all-or-nothing
    per subgraph.

Measured on the prunedGraph inference workload (batch=100,
infer-iters=200, H2D staging pool enabled):

  * pass OFF: ~8.74 ms avg
  * pass ON:  ~8.03 ms avg (-0.71 ms, ~8%)

5000-iter stability run: avg 8.05 ms, P99 8.82 ms, no correctness
issues. Can be disabled via `TF_MUSA_DISABLE_HOST_COMPUTE_PIN=1`.
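The all-or-nothing fixpoint can be sketched generically as follows. This is an assumed shape, not the actual pass: `safe(node, survivors)` stands in for the forward/backward HostMemory checks, and nodes are reduced to integers.

```cpp
#include <functional>
#include <set>

// Iterate candidate eviction to a fixpoint, capped at 32 sweeps like
// the pass described above. Only the survivors of the final sweep get
// their device() rewritten, so the decision is all-or-nothing.
std::set<int> PruneToFixpoint(
    std::set<int> candidates,
    const std::function<bool(int, const std::set<int>&)>& safe) {
  for (int sweep = 0; sweep < 32; ++sweep) {
    bool changed = false;
    for (auto it = candidates.begin(); it != candidates.end();) {
      if (!safe(*it, candidates)) {
        it = candidates.erase(it);  // eviction may invalidate peers...
        changed = true;
      } else {
        ++it;
      }
    }
    if (!changed) break;  // ...hence re-sweeping until nothing changes
  }
  return candidates;
}
```

Evicting one candidate can make a previously safe peer unsafe, which is why a single sweep is not enough and the loop must reach a fixpoint.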
Restores green runs of apply_adamax / assign_add / interact / layernorm /
resource_apply_nadam / resourcegather / tensorlist_{fromtensor,reserve,stack}
op tests under test_runner.py.

Test-side fixes
---------------
* musa_test_utils: load tf.load_op_library *before*
  load_pluggable_device_library so the returned module actually exposes
  plugin-registered custom ops (MusaInteract, MusaLayerNorm,
  ResourceApplyNadam, ...). The previous order left the follow-up
  load_op_library call returning an empty module because TF tags the .so
  as already-loaded. Expose the module via a new get_musa_ops() helper.
* interact / layernorm / resource_apply_nadam: dispatch through
  get_musa_ops().<op>() rather than tf.raw_ops.<Op>. tf.raw_ops is
  populated from the op registry snapshot taken at TF build time and
  does not include dynamically-registered plugin ops.
* apply_adamax: rewritten from scratch, modelled on
  apply_gradient_descent_op_test.py (tf.Graph + tf.compat.v1.Session per
  case, CPU vs MUSA parity across float32/float16/bfloat16, plus an fp64
  NumPy reference check and a use_locking parity test).

Kernel-side fixes
-----------------
* MusaResourceScatterAddOp: pass indices.NumElements() (not
  indices.shape().dim_sizes().size(), which is just the rank) as the
  leading dim of mScatterND's nd-info. The old formulation silently
  processed only the first index.
* AssignAdd/Sub VariableOp: add bfloat16 registration via a dedicated
  REGISTER_MUSA_ASSIGN_UPDATE_BF16 macro so
  tf.Variable(..., dtype=tf.bfloat16).assign_add(...) routes to MUSA.
* TensorListFromTensor / TensorListStack: fetch the raw stream via
  GetMusaStreamByCtx(ctx) and accept nullptr (== default stream in eager
  mode under the pluggable-device CStream wrapper); switch the internal
  memcpy direction to musaMemcpyDefault so list elements produced by the
  stock CPU-fallback TensorListSetItem (which leaves them in host
  memory) don't fail with musaErrorInvalidValue during Stack.
… destroy

Under long runs with TF_MUSA_H2D_STAGING_THRESHOLD_BYTES and
TF_MUSA_H2D_STAGING_MEMCPY_THREADS>1, inference would silently hang and
dmesg reported "force destroy app memory context". Root cause is a
concurrent-call race in PinnedStagingPool::ParallelMemcpyImpl: the
worker pool keeps a single job slot (job_dst_/job_src_/pending_/
generation_), but Run() releases mu_ between setup and done_cv_.wait so
a second concurrent Run() can clobber the slot before the first run's
workers have even picked up their chunks. The second run's workers then
service the second job and drive pending_ to zero, causing the first
Run() to return with most of its staging buffer uncopied; the stale
bytes are musaMemcpyAsync'd to device, downstream kernels consume the
garbage, and the driver tears the context down -- host threads then
block forever on events that never complete.

Guard Run() with an outer run_mu_ so only one job uses the shared slot
at a time; intra-run chunk parallelism is preserved. Also fix
DrainInFlightLocked to scan the whole deque instead of stopping at the
first unfinished event: in_flight_ mixes events from different streams
and devices, so there is no global completion order and the early break
was delaying pinned buffer reclaim.
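A toy model of the locking fix, with member names taken from the commit message but bodies that are purely illustrative (one inline worker instead of a thread pool):

```cpp
#include <condition_variable>
#include <cstring>
#include <mutex>
#include <thread>

// run_mu_ serializes whole jobs over the single shared slot (the fix);
// mu_/done_cv_ still coordinate the chunk workers within one job, so
// intra-run parallelism is preserved.
class PinnedStagingPool {
 public:
  void Run(char* dst, const char* src, size_t n) {
    std::lock_guard<std::mutex> outer(run_mu_);  // one job owns the slot
    {
      std::lock_guard<std::mutex> l(mu_);
      job_dst_ = dst;
      job_src_ = src;
      pending_ = 1;
    }
    // A worker picks up the slot and copies its chunk.
    std::thread worker([this, n] {
      std::memcpy(job_dst_, job_src_, n);
      std::lock_guard<std::mutex> l(mu_);
      if (--pending_ == 0) done_cv_.notify_all();
    });
    std::unique_lock<std::mutex> l(mu_);
    done_cv_.wait(l, [this] { return pending_ == 0; });
    l.unlock();
    worker.join();
  }

 private:
  std::mutex run_mu_;  // outer guard added by the fix
  std::mutex mu_;
  std::condition_variable done_cv_;
  char* job_dst_ = nullptr;
  const char* job_src_ = nullptr;
  int pending_ = 0;
};
```

Without `run_mu_`, a second `Run()` could overwrite `job_dst_`/`job_src_`/`pending_` while the first job's workers were still in flight, which is exactly the clobbering described above.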
Introduce musa_ext/mu/tf_compat.h as the single include point for the
PluggableDevice C API headers (stream_executor.h, tf_status.h). The shim
pins SE_MAJOR=0 with static_assert, validates the struct_size macros we
depend on, and reserves TF_MUSA_HAS_* feature flags for future
version-gated fields. Plugin code now uses "mu/tf_compat.h" uniformly.

setup.py / build.sh switch from a hard REQUIRED_TF_VERSION="2.6.1" check
to a range check MIN_TF_VERSION="2.6" / MAX_TF_VERSION_EXCLUSIVE="2.17"
with RECOMMENDED_TF_VERSION="2.6.1". Installed TF inside this range
builds successfully; outside the range the build aborts with a clear
message. Version strings with non-numeric suffixes (rc/dev) are handled.
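The suffix-tolerant range comparison lives in setup.py / build.sh; an equivalent sketch of the logic (assumed shape, illustrative function names):

```cpp
#include <cstdlib>
#include <string>

// Extract leading MAJOR.MINOR, tolerating rc/dev suffixes: atoi stops
// at the first non-digit, so "2.17.0rc1" parses as (2, 17).
static void ParseMajorMinor(const std::string& v, int* major, int* minor) {
  *major = std::atoi(v.c_str());
  const size_t dot = v.find('.');
  *minor = (dot == std::string::npos) ? 0 : std::atoi(v.c_str() + dot + 1);
}

// Half-open range [2.6, 2.17), mirroring MIN_TF_VERSION and
// MAX_TF_VERSION_EXCLUSIVE above.
bool TfVersionSupported(const std::string& v) {
  int major = 0, minor = 0;
  ParseMajorMinor(v, &major, &minor);
  const int key = major * 1000 + minor;
  return key >= 2006 && key < 2017;
}
```

Encoding the pair as `major * 1000 + minor` keeps the comparison a single integer check while ordering 2.6 below 2.17 correctly (a plain string compare would not).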

docs/tf-compat-matrix.md records the supported range, validation status
table, and the procedure for adding new versions or reacting to future
TF ABI bumps. Both README variants now point to the matrix.

Prerequisite for the caching allocator / VMM / _C pybind changes that
follow in the memory-h2d-research plan (stage C0).
Commit C1 of the memory-h2d-research plan.

- Add HostCachingAllocator (event_pool.h, host_caching_allocator.{h,cc}):
  size-class bucketing (pow2, min 64 KiB), total-cap env knob
  TF_MUSA_HOST_ALLOC_MAX_POOL_MB (default 2 GiB), disable switch
  TF_MUSA_DISABLE_HOST_CACHING, stream-ordered reuse via a shared
  per-device EventPool, stats for later _C bindings (C6).
- Rewrite host_memory_allocate/deallocate in musa_se_callbacks.cc to go
  through the caching allocator with a safe fallback to musaHostAlloc
  when the cache refuses or is disabled.
- Refactor PinnedStagingPool to borrow buffers from HostCachingAllocator
  and use RecordStream for deferred, event-based recycling; drop the
  duplicate in-flight queue and size-class code. Legacy
  TF_MUSA_H2D_STAGING_MAX_POOL_MB now emits a one-shot deprecation.
- Add benchmark/bench_host_alloc.py TF-eager A/B harness; measured
  1.16x speedup cached-vs-fresh on 64 KiB x 3000 iters. A sharper
  driver-level bench lands with the _C pybind (C6).

ABI and TF integration unchanged.
Commit C2 of the memory-h2d-research plan. The plugin is now shipped as
three cooperating shared objects so the host caching allocator (and the
device caching allocator that will land in C3) can be observed from both
TF's PluggableDevice callbacks and a future Python extension through
exactly one set of singletons.

- CMakeLists.txt: add `musa_core` (SHARED, owns singleton TUs:
  event_pool.cc + host_caching_allocator.cc), rewire `musa_plugin` to
  link against it, and introduce `tensorflow_musa__C` (MODULE) producing
  `_C.<pyext>.so`. Plugin + _C both carry `$ORIGIN` RPATH so they pick
  up the colocated libmusa_core.so.
- event_pool.cc: move `EventPool::Instance()` out-of-line into the core
  TU so the static local exists in exactly one DSO, regardless of how
  many other libraries include event_pool.h.
- host_caching_allocator.cc, event_pool.cc: convert both singletons to
  the leaked-new-pointer idiom; static-local destruction at process
  exit ran after libmusart/libtensorflow_framework were unloaded in the
  split layout, producing SIGSEGV at teardown. Leaking costs nothing.
- musa_ext/python/_C.cpp: minimal CPython extension (no pybind11 yet,
  that lands with C6). Exposes `_is_loaded()` and `_host_allocator_stats()`;
  the latter is used by the new layout test to validate the cross-library
  singleton is actually shared.
- setup.py: copy all three artifacts into the package payload.
- test/test_wheel_layout.py: new pytest covering presence, RPATH, NEEDED
  on libmusa_core.so, and the _C importability probe.

Manual verification: TF_MUSA_H2D_STAGING_THRESHOLD_BYTES=4096 run shows
_C.host_allocator_stats() reports `alloc_requests=20, cache_hits=19,
cache_misses=1` after a 20-iteration H2D loop driven through
libmusa_plugin.so, confirming one allocator instance across both
consumers; process exits cleanly.
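The leaked-new-pointer idiom from this commit, reduced to its core (the `Acquire` method is a placeholder for the real pool API):

```cpp
// The instance is constructed on first use and never destroyed, so no
// static-local destructor can run after dependent libraries
// (libmusart, libtensorflow_framework) have been unloaded at exit.
class EventPool {
 public:
  static EventPool& Instance() {
    static EventPool* pool = new EventPool();  // intentionally leaked
    return *pool;
  }
  int Acquire() { return ++acquired_; }

 private:
  EventPool() = default;
  ~EventPool() = default;  // private: nothing outside may destroy it
  int acquired_ = 0;
};
```

Defining `Instance()` out-of-line in exactly one TU (as the commit does for event_pool.cc) additionally guarantees one singleton across all DSOs, rather than one per library that happens to include the header.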
…om BFC

Stands up a per-device block/pool caching allocator for MUSA device
memory and wires it in front of the TF PluggableDevice bridge so that
most allocations are served from a reusable free list instead of
hitting `musaMalloc` on every request. This is the C3 step in the
memory/H2D plan and establishes the foundation for stream-aware
reuse, VMM expandable segments and the Python snapshot API in later
commits.

What lands in this commit:

* `mu/device/caching_allocator.{h,cc}`: a mutex-guarded, address-ordered
  block pool per ordinal with splitting on alloc and neighbour merging
  on free. Exposes `Allocate`, `Free`, `EmptyCache`, `GetStats`, and
  `ResetPeakStats`. Scope is deliberately limited for the MVP: no
  stream-ordered reuse yet, no VMM, no OOM observer — those gates
  arrive with C4/C5.

* `mu/device_register.cc`: `platform->use_bfc_allocator` is now driven
  by the `TF_MUSA_DEVICE_ALLOCATOR` env var. `caching` (default) turns
  BFC off and lets our allocator serve the device heap; `passthrough`
  reinstates BFC and routes our callbacks straight to `musaMalloc`,
  giving us an easy A/B toggle for regression bisection.

* `mu/device/musa_se_callbacks.cc`: the `allocate`/`deallocate`
  callbacks dispatch to the caching allocator or fall back to raw
  `musaMalloc/musaFree` based on the same env. The existing async
  allocator path is preserved untouched.

* `python/_C.cpp`: three new entry points on the `_C` extension —
  `_device_allocator_stats(ordinal=0)`, `_device_allocator_backend()`
  and `_device_empty_cache(ordinal=0)`. Because `_C.so` links
  `libmusa_core.so`, this confirms the cross-library singleton wired
  up in C2 survives a real device workload, not just the host path.

* `test/test_device_caching_allocator.py`: pytest suite that pins the
  public stats schema, checks the backend string, and — in an isolated
  subprocess to sidestep TF 2.6's `load_pluggable_device_library`
  re-entrancy abort under pytest — drives a 30-step eager matmul loop
  and asserts >=90% cache-hit ratio, non-trivial splits/merges, a
  monotonic peak counter, and at least one driver segment.

Measured end-to-end on a 256x256 matmul loop: 89/90 requests served
from cache (98.9% hit ratio), 32 splits, 29 merges, 1 driver segment.
The `passthrough` mode was re-tested and still routes through BFC
with identical numerical results, so rollback is one env-var away.
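The split-on-alloc / merge-on-free mechanics can be illustrated with a minimal address-ordered map; this is a teaching sketch, not the real allocator (no locking, no stats, and `0` stands in for the `musaMalloc` miss path):

```cpp
#include <cstdint>
#include <iterator>
#include <map>

struct Pool {
  std::map<uintptr_t, size_t> free_blocks;  // addr -> size, address-ordered

  uintptr_t Allocate(size_t n) {
    for (auto it = free_blocks.begin(); it != free_blocks.end(); ++it) {
      if (it->second >= n) {
        const uintptr_t addr = it->first;
        const size_t remainder = it->second - n;
        free_blocks.erase(it);
        if (remainder > 0) free_blocks.emplace(addr + n, remainder);  // split
        return addr;
      }
    }
    return 0;  // cache miss: real code would call musaMalloc here
  }

  void Free(uintptr_t addr, size_t n) {
    auto it = free_blocks.emplace(addr, n).first;
    auto next = std::next(it);
    if (next != free_blocks.end() && it->first + it->second == next->first) {
      it->second += next->second;  // merge with the following neighbour
      free_blocks.erase(next);
    }
    if (it != free_blocks.begin()) {
      auto prev = std::prev(it);
      if (prev->first + prev->second == it->first) {
        prev->second += it->second;  // merge with the preceding neighbour
        free_blocks.erase(it);
      }
    }
  }
};
```

Keeping the map keyed by address makes neighbour lookup O(log n) and lets adjacent frees coalesce back into the original segment, which is what keeps the "1 driver segment" figure above stable across 90 requests.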
Rounds out the device caching allocator so the Python layer (next
commit) has all the data it needs and operators get useful diagnostics
on failure. This is the C4 step in the memory/H2D plan.

Four additions:

* Segment snapshot. The allocator now tracks segment heads in a
  dedicated set and exposes `GetSegmentSnapshot()` returning one
  `DeviceSegmentInfo` per live segment (address, size, in-use bytes,
  block counts, largest free block). This is the primitive the future
  `memory_snapshot()` Python tool will build on.

* Memory fraction / limit. `SetMemoryFraction(f)` resolves
  `f * musaMemGetInfo.total` into a hard byte cap on what the
  allocator is willing to obtain from the driver; `SetMemoryLimitBytes`
  sets the same cap explicitly. Requests whose cache-miss path would
  cross the cap fail as OOM without touching musaMalloc, matching the
  `torch.musa.set_per_process_memory_fraction` contract. The cap is a
  hard ceiling only on NEW driver allocations: already-live blocks stay
  put even if the limit is later lowered, so a misconfigured cap can
  never tear down an in-flight tensor.

* OOM diagnostic. Every allocation failure (limit-would-be-exceeded
  or `musaMalloc` returned error) builds a multi-line message snapshot
  with device, requested bytes, reserved / in-use / peak, segments,
  cache hit/miss counters, the configured limit, and driver
  free/total. The message is stashed on the allocator for programmatic
  access via `GetLastOomMessage()`. Setting `TF_MUSA_ALLOC_VERBOSE_OOM=1`
  additionally mirrors the message to stderr every time — this is the
  equivalent of `torch_musa`'s `TORCH_MUSA_OOM_VERBOSE`.

* SP_StreamExecutor.get_allocator_stats wiring. In caching mode the
  SE callback now fills `SP_AllocatorStats` from
  `DeviceCachingAllocator::GetStats()` so tf.config.experimental
  get_memory_info sees real numbers; passthrough mode still returns
  false and lets TF's BFC populate its own stats.

Python surface on `tensorflow_musa._C`:

  * `_device_segment_snapshot(ordinal=0)` → list[dict]
  * `_device_set_memory_fraction(fraction, ordinal=0)` → int bytes
  * `_device_set_memory_limit_bytes(bytes, ordinal=0)` → None
  * `_device_reset_peak_stats(ordinal=0)` → None
  * `_device_last_oom_message(ordinal=0)` → str
  * `_device_memory_usage(ordinal=0)` → (free, total)
  * `_device_allocator_stats(ordinal=0)` now returns two new keys,
    `limit_bytes` and `total_device_bytes`, alongside the pre-C3 stats

Validated end-to-end with TF_MUSA_ALLOC_VERBOSE_OOM=1 and a deliberately
tiny 1 MiB limit: the allocator refuses a 64 MiB request, sets
`oom_events=1`, mirrors the diagnostic to stderr, and TF surfaces the
failure as `ResourceExhaustedError` to the caller. Test suite grew
from 9 to 11 checks covering the new schema keys, segment snapshot
invariants, and the fraction/limit round trip.
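The cap semantics described above condense to a few lines; this is an assumed shape with illustrative names, not the real allocator code:

```cpp
#include <cstdint>

// SetMemoryFraction resolves a fraction of total device memory into a
// byte cap; the cache-miss path then refuses any NEW driver allocation
// that would cross it. Already-live blocks are never checked, so
// lowering the limit cannot tear down in-flight tensors.
struct MemoryLimit {
  uint64_t limit_bytes = 0;  // 0 = unlimited

  uint64_t ResolveFraction(double f, uint64_t total_device_bytes) {
    limit_bytes = static_cast<uint64_t>(f * total_device_bytes);
    return limit_bytes;
  }

  bool AllowsNewSegment(uint64_t reserved_bytes, uint64_t request) const {
    return limit_bytes == 0 || reserved_bytes + request <= limit_bytes;
  }
};
```

A request refused here fails as OOM before `musaMalloc` is ever called, which is what lets the diagnostic snapshot report "limit-would-be-exceeded" separately from genuine driver failures.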
Introduces the VMM-backed expandable-segments allocation path behind
the TF_MUSA_ALLOC_CONF=expandable_segments:true env flag. Opt-in and
gated on runtime driver capability, so the existing musaMalloc path
remains the default.

  * driver_api.{h,cc}: dlopen libmusa.so + dlsym for the VMM surface
    (muMemAddressReserve/Create/Map/SetAccess/Unmap/Release/
    AddressFree/GetAllocationGranularity) plus muDeviceGetAttribute
    and muGetErrorString. Missing symbols disable VMM gracefully.
    Per-device capability probe forces driver init via musaSetDevice
    before querying MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_
    SUPPORTED, caching only positive results.

  * allocator_config.{h,cc}: parser for TF_MUSA_ALLOC_CONF with the
    torch_musa key:value syntax. Only expandable_segments is honored
    in this commit; max_split_size_mb / roundup_power2_divisions /
    garbage_collection_threshold are parsed and exposed for forward
    compatibility.

  * expandable_segment.{h,cc}: per-segment RAII wrapper that rounds
    the request to driver granularity, reserves VA, creates a pinned
    handle, maps, and grants R/W access on the owning device. Dtor
    unwinds in reverse. Driver failures surface as a null factory
    return; callers (the caching allocator) then fall back to
    musaMalloc transparently.

  * caching_allocator: Block gains an `expandable` pointer;
    AllocateSegmentUnlocked routes to ExpandableSegment::Create
    when the env flag + driver capability agree, and EmptyCache
    dispatches release to the VMM unmap path or musaFree based on
    that pointer. Segment snapshots now report is_expandable
    accurately.

  * _C Python surface: _vmm_available / _vmm_supported /
    _vmm_granularity / _allocator_config diagnostics, no new
    allocation entry points (those land with the pybind C6 surface).

  * test/test_vmm_expandable_segments.py: schema + probe contracts
    plus a subprocess-isolated end-to-end that asserts every live
    segment is VMM-backed when the env flag is on and the device
    supports VMM, and that the flag is a no-op when off.
Adds a narrow Python surface (tensorflow_musa.memory / .device) over
the _C extension:

- python/_ext.py: package-aware, cached lazy loader for _C, resolving
  both installed-wheel and in-tree build/ layouts so the wrappers work
  under pytest, pip install, and ad-hoc source-tree runs.
- python/memory.py: empty_cache / memory_allocated / memory_reserved /
  max_memory_allocated / reset_peak_memory_stats / memory_stats /
  set_per_process_memory_fraction / mem_get_info / get_allocator_backend.
  Matches the shape of torch.musa.memory.* where it doesn't conflict
  with TF idioms; deliberately omits Stream/Event/MemPool per plan §6.3.
- python/device.py: device_count / current_device / get_device_name /
  is_available. Device discovery goes through TF so MUSA_VISIBLE_DEVICES
  and plugin visibility stay consistent with what TF actually schedules.
- python/__init__.py: re-exports the new submodules without forcing _C
  to load at import time (the loader is lazy), keeping the package
  import cheap and usable on CPU-only hosts.
- test/test_python_memory_api.py: shape tests that work without MUSA
  hardware plus an end-to-end subprocess test that drives TF traffic
  and observes it through the new wrappers.
Adds the optional diagnostic tier called out in plan §5.6 / §6.3:

- python/snapshot.py
  - memory_snapshot(device=None) -> dict: composes stats + segments +
    driver free/total + parsed TF_MUSA_ALLOC_CONF + VMM probes +
    last-OOM message + (when active) sampled history, all from
    existing _C entry points. No new native ABI.
  - _dump_snapshot(path, device=None): atomic JSON write via
    <path>.tmp + os.replace so file-watchers see either the old or
    the new file, never a truncated one.
  - _record_memory_history(enabled, max_entries=1024, interval_ms=50):
    daemon-thread ring-buffer sampler over _device_allocator_stats.
    Sampling (vs alloc/free hooks) keeps the native contract flat and
    the hot path untouched; the note at §4.2 S5 explicitly allows
    "size+stream+event" snapshots without symbolicated stacks for
    the first drop.
- python/memory.py: re-exports the three helpers so the plan's public
  API lives in one place.
- test/test_memory_snapshot.py: shape checks, JSON round-trip,
  ring-buffer trim, monotonic timestamps, and confirmation that the
  sampler is off by default.
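The `_dump_snapshot` write path relies on the standard atomic-replace pattern; the same technique in C++ looks like this (a generic sketch, not the project's Python code):

```cpp
#include <cstdio>
#include <fstream>
#include <string>

// Write to <path>.tmp, then rename over the target. On POSIX, rename
// atomically replaces the destination, so a concurrent reader sees
// either the old file or the new one, never a truncated mix.
bool AtomicWrite(const std::string& path, const std::string& contents) {
  const std::string tmp = path + ".tmp";
  {
    std::ofstream out(tmp, std::ios::binary | std::ios::trunc);
    if (!out) return false;
    out << contents;
  }  // stream closed (and flushed) before the rename
  return std::rename(tmp.c_str(), path.c_str()) == 0;
}
```

The key ordering is close-then-rename: renaming before the temp file is flushed would let watchers observe a short read.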
Introduces the three benchmark scripts called out in plan §5.5
(task 7-a) plus the CI regression backstop (task 7-b):

- benchmark/bench_h2d.py: pure H2D / D2H throughput sweep across a
  size range (4K..16M by default). Uses subprocess isolation so every
  run starts with a cold allocator; reports mean / p50 / p95 / GB/s.
- benchmark/bench_resnet.py: 100-step synthetic training loop with
  ResNet50 when tf.keras is healthy, falling back to a hand-rolled
  conv stack when keras's dtensor import is broken (common on TF 2.6
  with a stand-alone keras wheel). Hand-rolled SGD avoids pulling in
  tf.optimizers; conv shapes chosen so MUSA's symmetric-padding
  requirement is met. Reports mean / p50 / p95 step time plus peak
  allocator footprint.
- benchmark/bench_alloc_churn.py: mixed-size allocation storm that
  exercises the caching allocator's split/merge paths. Asserts four
  invariants at the end (no leaks, no spurious OOM, empty_cache
  releases >=95% of reserved, cache_hits > 0) and exits non-zero on
  any violation.
- test/test_bench_alloc_churn.py: wraps the churn bench in a cheap
  pytest (iters=20, batch=8) so CI can run it as part of the regular
  suite and fail fast when the allocator regresses.
First deliverable in the plan's "optional" phase (§3.8, §4.3 C3):
multi-GPU peer-access capability probes and enable control, with a
process-wide cache so repeat queries stay cheap.

- musa_ext/mu/device/peer_access.{h,cc}: leaked-singleton PeerAccess
  that wraps musaDeviceCanAccessPeer + musaDeviceEnablePeerAccess.
  Saves/restores the caller's current device around the enable call
  so the probe is side-effect free; squashes ErrorPeerAccessAlreadyEnabled
  to success for idempotency; exposes a Snapshot() view of observed
  pairs for diagnostics.
- musa_ext/python/_C.cpp: four new entries (_peer_device_count,
  _peer_can_access, _peer_enable_access, _peer_access_snapshot).
- python/device.py: user-facing can_access_peer / enable_peer_access /
  peer_access_snapshot. Documents that the plugin's memcpy_dtod path
  does NOT yet auto-dispatch peer copies; these entry points give
  multi-GPU users a way to enable access manually (or drive their own
  musaMemcpyPeerAsync call).
- test/test_peer_access.py: validates self-access, argument handling,
  out-of-range behavior, and — when >= 2 MUSA devices are present —
  the cached lookup and idempotent enable paths.
- CMakeLists.txt: registers peer_access.cc in MUSA_CORE_SOURCES so
  both libmusa_core.so and _C.so see the same singleton, matching the
  plan §6.1 "shared state across libraries" architecture.

This commit is intentionally decoupled from the main allocator path;
follow-ups can add peer-aware memcpy dispatch and extend VMM
expandable segments with peer muMemSetAccess lists, both listed in
plan §3.8 but gated on real multi-GPU demand.
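The idempotency squash plus per-pair cache can be modelled as below; the error values and `DriverEnable` body are stand-ins for the `musaError_t` codes and driver call, not real MUSA API:

```cpp
#include <map>
#include <utility>

enum FakeError { kSuccess = 0, kPeerAccessAlreadyEnabled = 704 };

// "Already enabled" is mapped to success, and results are cached per
// (src, dst) pair so repeat queries never re-enter the driver.
struct PeerAccess {
  std::map<std::pair<int, int>, bool> cache;
  int driver_calls = 0;

  bool Enable(int src, int dst) {
    const auto key = std::make_pair(src, dst);
    auto it = cache.find(key);
    if (it != cache.end()) return it->second;  // cached: no driver call
    const FakeError e = DriverEnable(src, dst);
    const bool ok = (e == kSuccess || e == kPeerAccessAlreadyEnabled);
    cache.emplace(key, ok);
    return ok;
  }

  // Stand-in for musaDeviceEnablePeerAccess; always reports
  // "already enabled" to exercise the squash path.
  FakeError DriverEnable(int, int) { ++driver_calls; return kPeerAccessAlreadyEnabled; }
};
```

Caching both positive and negative results keeps repeat `can_access_peer` queries O(log n) map lookups, matching the "process-wide cache so repeat queries stay cheap" goal.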
`DeviceCachingAllocator::Allocate()` used to release the allocator
mutex between two critical sections on the cache-hit path: the first
erased the block from `free_blocks` and called `MaybeSplitLocked`, and
the second re-acquired the lock to set `block->allocated = true` and
emplace the block into `active_blocks`. In the gap the block was
still linked into its segment's `prev`/`next` chain but looked free
(`allocated == false`) and lived in neither map. A concurrent `Free()`
on an address-adjacent block would then enter `MergeNeighborsLocked`,
see `!block->allocated`, absorb our just-taken block and `delete` it,
leaving a dangling pointer in `free_blocks` and causing the next
cache-hit lookup to return a `Block*` whose memory had been recycled.

The concrete crash signatures on a multi-threaded TF inference run:

  free(): double free detected in tcache 2
  Check failed: IsAligned() ptr = 0x<non-16B-aligned>

Fix: fold the "commit to active" step (`block->allocated = true`,
`active_blocks.emplace`, in_use / peak accounting) into the first
critical section for both the cache-hit and cache-miss paths, so a
concurrent merge can never observe this block as mergeable. The
driver call (`musaMalloc` / VMM reserve) still happens with the lock
released, which was the only reason the second critical section
existed in the first place.

Adds `test_concurrent_allocate_free_does_not_corrupt_blocks`, a
subprocess-isolated stress test that drives four host threads through
differently-shaped matmuls against the same per-device allocator and
asserts the workload runs to completion with `in_use_bytes == 0` at
the end.
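The fixed cache-hit path can be modelled minimally as follows; names follow the commit message, but the structure is heavily simplified (no splitting, no driver path) to show only the single-critical-section commit:

```cpp
#include <cstdint>
#include <map>
#include <mutex>

struct Block {
  uintptr_t ptr;
  size_t size;
  bool allocated = false;
};

struct Allocator {
  std::mutex mu_;
  std::map<uintptr_t, Block*> free_blocks;
  std::map<uintptr_t, Block*> active_blocks;

  Block* Allocate(size_t n) {
    std::lock_guard<std::mutex> l(mu_);
    for (auto it = free_blocks.begin(); it != free_blocks.end(); ++it) {
      if (it->second->size >= n) {
        Block* b = it->second;
        free_blocks.erase(it);
        b->allocated = true;              // committed before mu_ is released,
        active_blocks.emplace(b->ptr, b); // so no Free()/merge can see the
        return b;                         // block as free-but-unowned
      }
    }
    return nullptr;  // cache miss: the driver call would run with mu_
                     // released, then commit under one re-acquisition
  }
};
```

The invariant restored by the fix: at every instant the lock is not held, each block is either in `free_blocks` with `allocated == false` or in `active_blocks` with `allocated == true`, never in limbo between the two.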
Prepare the repo for open-source release:

* README (zh + en) slimmed to a professional overview: condense the
  feature list, drop internal design references and verbose debugging
  sections, and point memory tuning at the new public env-var doc.
* `docs/` is now user-facing only and contains just
  `environment-variables.md`, a pure reference for every `TF_MUSA_*`
  knob grouped by subsystem (device allocator, host pinned, H2D
  staging, auto-pin, event pool, graph optimizer).
* Internal design / debugging / compatibility docs move under
  `internal-docs/`:
    - `architecture-and-memory.md` (new): end-to-end architecture and
      memory scheduling reference (retained for internal use).
    - `memory-optimization.md`: history of the torch_musa-inspired
      memory work, with the env-var table aligned to current defaults.
    - `DEBUG_GUIDE.md`, `tf-compat-matrix.md`: moved with git history.
    - `README.md`: explains the split and how to exclude the
      directory from open-source releases.

No code changes; wheel / build layout is unaffected.
`./build.sh debug` used to flip on a musaEvent-based per-kernel timing
scope gated behind the `MUSA_KERNEL_DEBUG` compile flag. The
implementation had drifted out of sync with the rest of the plugin, its
env-var surface (`MUSA_TIMING_KERNEL_LEVEL` / `_NAME` / `_STATS`) was
never finished, and it duplicated device-side timing that TF profiler
already exposes. Remove the feature outright:

* `musa_ext/utils/logging.h`: delete the ~760 LOC timing scope
  (`KernelTimingScope`, `KernelTimingConfig`, `KernelTimingStatsRegistry`
  and the musaEvent lifecycle helpers) that lived under
  `#ifdef MUSA_KERNEL_DEBUG`. The existing kernel call sites
  (`MUSA_KERNEL_TIMING_GUARD(ctx)`, `MUSA_KERNEL_TRACE_START/END(...)`)
  are preserved as unconditional no-op macros so the 50+ kernel files
  that use them keep compiling without churn, and we retain a clean
  hook point if per-kernel timing is reintroduced later.
* `CMakeLists.txt`: drop the `MUSA_KERNEL_DEBUG` CMake option, the
  `if(MUSA_KERNEL_DEBUG)` flag branch (including the stale
  `-DMUSA_KERNEL_DEBUG` define), and the matching status line.
* `setup.py`: remove `-DMUSA_KERNEL_DEBUG=OFF` from the CMake invocation.
* `build.sh`: stop threading `MUSA_KERNEL_DEBUG` through to CMake and
  reword the `debug` build mode to reflect what it actually does now —
  `CMAKE_BUILD_TYPE=Debug`, i.e. unoptimized with `-g -O0` for gdb.
* `README.md` / `README.en.md`: update the Debug-mode row accordingly.
* `internal-docs/DEBUG_GUIDE.md`: drop the kernel-timing and memory-
  coloring sections (the latter was doc-only and never had backing
  code), prune `MUSA_TIMING_KERNEL_*` from the env-var table, and
  redirect the performance-troubleshooting checklist at TF profiler
  plus the public `docs/environment-variables.md` reference.
@tngchien tngchien closed this Apr 24, 2026
