added 22 commits on April 23, 2026 16:40
Replace the in-tree device/stream/event/allocator implementations with the
stable `tensorflow/c/experimental/stream_executor/stream_executor.h` plugin
API so the MUSA extension builds against the official PluggableDevice
contract rather than reaching into TF's private common_runtime headers.
Main changes:
- Drop the bespoke MUSA device stack (musa_allocator, musa_device,
musa_event, musa_event_mgr, musa_executor, musa_host_allocator,
musa_platform, musa_stream, musa_utils, pinned_memory_pool).
- Add musa_se_callbacks.{cc,h} implementing SP_Platform / SP_StreamExecutor
/ SP_TimerFns callbacks (create/destroy stream + event, memcpy H2D/D2H/
D2D, sync variants, allocator, host-memory pinning cache, H2D pinned
staging pool for pageable sources, and a small event pool to avoid
per-step driver churn).
- Add musa_resource_mgr.{cc,h} to cache mudnn/mublas handles per
(device, stream) and share them across kernels; thread-local fast path
removes the old mutex on every Compute().
- Rewrite device_register.cc to register the plugin through SE_InitPlugin
with idempotency + constructor/destructor hooks for telemetry.
- Update utils_op.{cc,h} with new helpers (GetMusaStreamByCtx,
GetHandleByCtx, GetMublasHandleByCtx, MakeMusaMemMaintainer,
CachedMusaSetDevice). Propagate the new helpers through all kernels
that used to reach into MusaDevice directly.
- Adjust CMakeLists.txt: drop include paths into TF's common_runtime,
add -DTF_PLUGGABLE_DEVICE.
- test/musa_test_utils.py now loads the plugin via
`load_pluggable_device_library` (with an idempotency guard to survive
repeated test-runner invocations).
- README/README.en: document the new plugin loading flow.
MusaMemcpy{D2H,H2D,D2D} and their Async siblings used to call
`musaGetDevice` unconditionally so the telemetry event could carry the
device ordinal. In release builds the telemetry macro is a no-op, which
makes that driver round-trip pure overhead. Hot kernels (AddN, Concat,
ResourceVariable assign, ...) walk these paths many times per step, so
the cost is observable on CPU-bound inference.
Wrap the prologue in a `MUSA_MEMCPY_TELEMETRY_PROLOGUE` macro that is
compiled out when `MUSA_DISABLE_TRACE_LOGGING` is defined, keeping full
telemetry for tracing builds and removing the overhead everywhere else.
`UseAddBroadcastViewOpt` and `UseAddCustomKernelFastPath` called `std::getenv` on every invocation. Inference graphs can contain hundreds of AddV2 nodes, so each Compute() was paying an O(n) walk of the env table. Cache the result in function-local static booleans evaluated on the first call. No behavior change: same env vars, same default (enabled).
The matcher emits two `LOG(INFO)` messages for every successful pattern match and apply. On a mid-size inference graph this fires hundreds of times per Optimize() call and clutters the tf.Session startup log without adding diagnostic value (the pattern manager already reports aggregate counts). Comment them out so they can be re-enabled with a one-line revert if a matcher regression needs to be debugged.
When callers import graphs under `with tf.device("/device:MUSA:0")`,
every node inherits an explicit MUSA device placement, including the
small int32 Shape/StridedSlice/Pack chains that feed Reshape/Tile/
BroadcastTo/... through HostMemory-typed inputs. TF's own
`pin_to_host_optimization` refuses to override explicit placements, so
the chain stays on the device and each hop triggers an H2D or D2H copy
plus a `musaStreamSynchronize` because the consumer reads host memory. On
this workload that pattern accounts for ~3 ms of profiled time per
step (batch=100) concentrated in 3 Pack + 4 StridedSlice nodes.
Add a PinHostComputeToCpu pass that runs after fusion and rewrites
`node.device()` to `/device:CPU:0` for shape-arithmetic subgraphs.
Correctness constraints:
* Only structurally safe ops are candidates: Shape/ShapeN/Size/Rank
unconditionally, plus an integer-typed set (Const, Cast,
StridedSlice, Pack, ConcatV2, Range, Fill, Reshape, arithmetic,
reductions, comparisons, booleans). FP variants are never moved.
* Forward check: every non-peer consumer must treat the output as a
HostMemory input (table is kept in sync with the MUSA kernel
.HostMemory(...) registrations).
* Backward check: every non-peer input source must be a peer, a
Const, or an op whose MUSA kernel already outputs to host memory.
* The candidate set is iterated to a fixpoint (capped at 32 sweeps)
before any device() is rewritten, so the decision is all-or-nothing
per subgraph.
Measured on the prunedGraph inference workload (batch=100,
infer-iters=200, H2D staging pool enabled):
* pass OFF: ~8.74 ms avg
* pass ON: ~8.03 ms avg (-0.71 ms, ~8%)
5000-iter stability run: avg 8.05 ms, P99 8.82 ms, no correctness
issues. Can be disabled via `TF_MUSA_DISABLE_HOST_COMPUTE_PIN=1`.
Restores green runs of apply_adamax / assign_add / interact / layernorm /
resource_apply_nadam / resourcegather / tensorlist_{fromtensor,reserve,stack}
op tests under test_runner.py.
Test-side fixes
---------------
* musa_test_utils: call tf.load_op_library *before*
  load_pluggable_device_library so the returned module actually exposes
  plugin-registered custom ops (MusaInteract, MusaLayerNorm,
  ResourceApplyNadam, ...). The previous order left the follow-up
  load_op_library call returning an empty module because TF tags the .so
  as already loaded. Expose the module via a new get_musa_ops() helper.
* interact / layernorm / resource_apply_nadam: dispatch through
get_musa_ops().<op>() rather than tf.raw_ops.<Op>. tf.raw_ops is
populated from the op registry snapshot taken at TF build time and
does not include dynamically-registered plugin ops.
* apply_adamax: rewritten from scratch, modelled on
apply_gradient_descent_op_test.py (tf.Graph + tf.compat.v1.Session per
case, CPU vs MUSA parity across float32/float16/bfloat16, plus an fp64
NumPy reference check and a use_locking parity test).
Kernel-side fixes
-----------------
* MusaResourceScatterAddOp: pass indices.NumElements() (not
indices.shape().dim_sizes().size(), which is just the rank) as the
leading dim of mScatterND's nd-info. The old formulation silently
processed only the first index.
* AssignAdd/Sub VariableOp: add bfloat16 registration via a dedicated
REGISTER_MUSA_ASSIGN_UPDATE_BF16 macro so
tf.Variable(..., dtype=tf.bfloat16).assign_add(...) routes to MUSA.
* TensorListFromTensor / TensorListStack: fetch the raw stream via
GetMusaStreamByCtx(ctx) and accept nullptr (== default stream in eager
mode under the pluggable-device CStream wrapper); switch the internal
memcpy direction to musaMemcpyDefault so list elements produced by the
stock CPU-fallback TensorListSetItem (which leaves them in host
memory) don't fail with musaErrorInvalidValue during Stack.
Under long runs with TF_MUSA_H2D_STAGING_THRESHOLD_BYTES set and
TF_MUSA_H2D_STAGING_MEMCPY_THREADS>1, inference would silently hang and
dmesg reported "force destroy app memory context".

Root cause is a concurrent-call race in
PinnedStagingPool::ParallelMemcpyImpl: the worker pool keeps a single
job slot (job_dst_/job_src_/pending_/generation_), but Run() releases
mu_ between setup and done_cv_.wait, so a second concurrent Run() can
clobber the slot before the first run's workers have even picked up
their chunks. The second run's workers then service the second job and
drive pending_ to zero, causing the first Run() to return with most of
its staging buffer uncopied; the stale bytes are musaMemcpyAsync'd to
the device, downstream kernels consume the garbage, and the driver tears
the context down -- host threads then block forever on events that never
complete.

Guard Run() with an outer run_mu_ so only one job uses the shared slot
at a time; intra-run chunk parallelism is preserved.

Also fix DrainInFlightLocked to scan the whole deque instead of stopping
at the first unfinished event: in_flight_ mixes events from different
streams and devices, so there is no global completion order and the
early break was delaying pinned-buffer reclaim.
Introduce musa_ext/mu/tf_compat.h as the single include point for the
PluggableDevice C API headers (stream_executor.h, tf_status.h). The shim
pins SE_MAJOR=0 with static_assert, validates the struct_size macros we
depend on, and reserves TF_MUSA_HAS_* feature flags for future
version-gated fields. Plugin code now includes "mu/tf_compat.h"
uniformly.

setup.py / build.sh switch from a hard REQUIRED_TF_VERSION="2.6.1" check
to a range check MIN_TF_VERSION="2.6" / MAX_TF_VERSION_EXCLUSIVE="2.17"
with RECOMMENDED_TF_VERSION="2.6.1". An installed TF inside this range
builds successfully; outside the range the build aborts with a clear
message. Version strings with non-numeric suffixes (rc/dev) are handled.

docs/tf-compat-matrix.md records the supported range, a validation-status
table, and the procedure for adding new versions or reacting to future
TF ABI bumps. Both README variants now point to the matrix.

Prerequisite for the caching allocator / VMM / _C pybind changes that
follow in the memory-h2d-research plan (stage C0).
Commit C1 of the memory-h2d-research plan.
- Add HostCachingAllocator (event_pool.h, host_caching_allocator.{h,cc}):
size-class bucketing (pow2, min 64 KiB), total-cap env knob
TF_MUSA_HOST_ALLOC_MAX_POOL_MB (default 2 GiB), disable switch
TF_MUSA_DISABLE_HOST_CACHING, stream-ordered reuse via a shared
per-device EventPool, stats for later _C bindings (C6).
- Rewrite host_memory_allocate/deallocate in musa_se_callbacks.cc to go
through the caching allocator with a safe fallback to musaHostAlloc
when the cache refuses or is disabled.
- Refactor PinnedStagingPool to borrow buffers from HostCachingAllocator
and use RecordStream for deferred, event-based recycling; drop the
duplicate in-flight queue and size-class code. Legacy
TF_MUSA_H2D_STAGING_MAX_POOL_MB now emits a one-shot deprecation.
- Add benchmark/bench_host_alloc.py TF-eager A/B harness; measured
1.16x speedup cached-vs-fresh on 64 KiB x 3000 iters. A sharper
driver-level bench lands with the _C pybind (C6).
ABI and TF integration unchanged.
Commit C2 of the memory-h2d-research plan. The plugin is now shipped as
three cooperating shared objects so the host caching allocator (and the
device caching allocator that will land in C3) can be observed from both
TF's PluggableDevice callbacks and a future Python extension through
exactly one set of singletons.
- CMakeLists.txt: add `musa_core` (SHARED, owns the singleton TUs:
  event_pool.cc + host_caching_allocator.cc), rewire `musa_plugin` to
  link against it, and introduce `tensorflow_musa__C` (MODULE) producing
  `_C.<pyext>.so`. Plugin + _C both carry `$ORIGIN` RPATH so they pick
  up the colocated libmusa_core.so.
- event_pool.cc: move `EventPool::Instance()` out-of-line into the core
  TU so the static local exists in exactly one DSO, regardless of how
  many other libraries include event_pool.h.
- host_caching_allocator.cc, event_pool.cc: convert both singletons to
  the leaked-new-pointer idiom; static-local destruction at process exit
  ran after libmusart/libtensorflow_framework were unloaded in the split
  layout, producing SIGSEGV at teardown. Leaking costs nothing.
- musa_ext/python/_C.cpp: minimal CPython extension (no pybind11 yet;
  that lands with C6). Exposes `_is_loaded()` and
  `_host_allocator_stats()`; the latter is used by the new layout test
  to validate that the cross-library singleton is actually shared.
- setup.py: copy all three artifacts into the package payload.
- test/test_wheel_layout.py: new pytest covering artifact presence,
  RPATH, the NEEDED entry on libmusa_core.so, and the _C importability
  probe.
Manual verification: a TF_MUSA_H2D_STAGING_THRESHOLD_BYTES=4096 run
shows _C.host_allocator_stats() reporting `alloc_requests=20,
cache_hits=19, cache_misses=1` after a 20-iteration H2D loop driven
through libmusa_plugin.so, confirming one allocator instance across both
consumers; the process exits cleanly.
Stands up a per-device block/pool caching allocator for MUSA device
memory and wires it in front of the TF PluggableDevice bridge so that
most allocations are served from a reusable free list instead of
hitting `musaMalloc` on every request. This is the C3 step in the
memory/H2D plan and establishes the foundation for stream-aware
reuse, VMM expandable segments and the Python snapshot API in later
commits.
What lands in this commit:
* `mu/device/caching_allocator.{h,cc}`: a mutex-guarded, address-ordered
block pool per ordinal with splitting on alloc and neighbour merging
on free. Exposes `Allocate`, `Free`, `EmptyCache`, `GetStats`, and
`ResetPeakStats`. Scope is deliberately limited for the MVP: no
stream-ordered reuse yet, no VMM, no OOM observer — those gates
arrive with C4/C5.
* `mu/device_register.cc`: `platform->use_bfc_allocator` is now driven
by the `TF_MUSA_DEVICE_ALLOCATOR` env var. `caching` (default) turns
BFC off and lets our allocator serve the device heap; `passthrough`
reinstates BFC and routes our callbacks straight to `musaMalloc`,
giving us an easy A/B toggle for regression bisection.
* `mu/device/musa_se_callbacks.cc`: the `allocate`/`deallocate`
callbacks dispatch to the caching allocator or fall back to raw
`musaMalloc/musaFree` based on the same env. The existing async
allocator path is preserved untouched.
* `python/_C.cpp`: three new entry points on the `_C` extension —
`_device_allocator_stats(ordinal=0)`, `_device_allocator_backend()`
and `_device_empty_cache(ordinal=0)`. Because `_C.so` links
`libmusa_core.so`, this confirms the cross-library singleton wired
up in C2 survives a real device workload, not just the host path.
* `test/test_device_caching_allocator.py`: pytest suite that pins the
public stats schema, checks the backend string, and — in an isolated
subprocess to sidestep TF 2.6's `load_pluggable_device_library`
re-entrancy abort under pytest — drives a 30-step eager matmul loop
and asserts >=90% cache-hit ratio, non-trivial splits/merges, a
monotonic peak counter, and at least one driver segment.
Measured end-to-end on a 256x256 matmul loop: 89/90 requests served
from cache (98.9% hit ratio), 32 splits, 29 merges, 1 driver segment.
The `passthrough` mode was re-tested and still routes through BFC
with identical numerical results, so rollback is one env-var away.
Rounds out the device caching allocator so the Python layer (next
commit) has all the data it needs and operators get useful diagnostics
on failure. This is the C4 step in the memory/H2D plan.
Four additions:
* Segment snapshot. The allocator now tracks segment heads in a
dedicated set and exposes `GetSegmentSnapshot()` returning one
`DeviceSegmentInfo` per live segment (address, size, in-use bytes,
block counts, largest free block). This is the primitive the future
`memory_snapshot()` Python tool will build on.
* Memory fraction / limit. `SetMemoryFraction(f)` resolves
`f * musaMemGetInfo.total` into a hard byte cap on what the
allocator is willing to obtain from the driver; `SetMemoryLimitBytes`
sets the same cap explicitly. Requests whose cache-miss path would
cross the cap fail as OOM without touching musaMalloc, matching the
`torch.musa.set_per_process_memory_fraction` contract. The cap is a
hard ceiling only on NEW driver allocations: already-live blocks stay
put even if the limit is later lowered, so a misconfigured cap can
never tear down an in-flight tensor.
* OOM diagnostic. Every allocation failure (limit-would-be-exceeded
or `musaMalloc` returned error) builds a multi-line message snapshot
with device, requested bytes, reserved / in-use / peak, segments,
cache hit/miss counters, the configured limit, and driver
free/total. The message is stashed on the allocator for programmatic
access via `GetLastOomMessage()`. Setting `TF_MUSA_ALLOC_VERBOSE_OOM=1`
additionally mirrors the message to stderr every time — this is the
equivalent of `torch_musa`'s `TORCH_MUSA_OOM_VERBOSE`.
* SP_StreamExecutor.get_allocator_stats wiring. In caching mode the
SE callback now fills `SP_AllocatorStats` from
`DeviceCachingAllocator::GetStats()` so tf.config.experimental
get_memory_info sees real numbers; passthrough mode still returns
false and lets TF's BFC populate its own stats.
Python surface on `tensorflow_musa._C`:
* `_device_segment_snapshot(ordinal=0)` → list[dict]
* `_device_set_memory_fraction(fraction, ordinal=0)` → int bytes
* `_device_set_memory_limit_bytes(bytes, ordinal=0)` → None
* `_device_reset_peak_stats(ordinal=0)` → None
* `_device_last_oom_message(ordinal=0)` → str
* `_device_memory_usage(ordinal=0)` → (free, total)
* `_device_allocator_stats(ordinal=0)` now returns two new keys,
`limit_bytes` and `total_device_bytes`, alongside the pre-C3 stats
Validated end-to-end with TF_MUSA_ALLOC_VERBOSE_OOM=1 and a deliberately
tiny 1 MiB limit: the allocator refuses a 64 MiB request, sets
`oom_events=1`, mirrors the diagnostic to stderr, and TF surfaces the
failure as `ResourceExhaustedError` to the caller. Test suite grew
from 9 to 11 checks covering the new schema keys, segment snapshot
invariants, and the fraction/limit round trip.
Introduces the VMM-backed expandable-segments allocation path behind
the TF_MUSA_ALLOC_CONF=expandable_segments:true env flag. Opt-in and
gated on runtime driver capability, so the existing musaMalloc path
remains the default.
* driver_api.{h,cc}: dlopen libmusa.so + dlsym for the VMM surface
(muMemAddressReserve/Create/Map/SetAccess/Unmap/Release/
AddressFree/GetAllocationGranularity) plus muDeviceGetAttribute
and muGetErrorString. Missing symbols disable VMM gracefully.
Per-device capability probe forces driver init via musaSetDevice
before querying MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_
SUPPORTED, caching only positive results.
* allocator_config.{h,cc}: parser for TF_MUSA_ALLOC_CONF with the
torch_musa key:value syntax. Only expandable_segments is honored
in this commit; max_split_size_mb / roundup_power2_divisions /
garbage_collection_threshold are parsed and exposed for forward
compatibility.
* expandable_segment.{h,cc}: per-segment RAII wrapper that rounds
the request to driver granularity, reserves VA, creates a pinned
handle, maps, and grants R/W access on the owning device. Dtor
unwinds in reverse. Driver failures surface as a null factory
return; callers (the caching allocator) then fall back to
musaMalloc transparently.
* caching_allocator: Block gains an `expandable` pointer;
AllocateSegmentUnlocked routes to ExpandableSegment::Create
when the env flag + driver capability agree, and EmptyCache
dispatches release to the VMM unmap path or musaFree based on
that pointer. Segment snapshots now report is_expandable
accurately.
* _C Python surface: _vmm_available / _vmm_supported /
_vmm_granularity / _allocator_config diagnostics, no new
allocation entry points (those land with the pybind C6 surface).
* test/test_vmm_expandable_segments.py: schema + probe contracts
plus a subprocess-isolated end-to-end that asserts every live
segment is VMM-backed when the env flag is on and the device
supports VMM, and that the flag is a no-op when off.
Adds a narrow Python surface (tensorflow_musa.memory / .device) over the
_C extension:
- python/_ext.py: package-aware, cached lazy loader for _C, resolving
  both installed-wheel and in-tree build/ layouts so the wrappers work
  under pytest, pip install, and ad-hoc source-tree runs.
- python/memory.py: empty_cache / memory_allocated / memory_reserved /
  max_memory_allocated / reset_peak_memory_stats / memory_stats /
  set_per_process_memory_fraction / mem_get_info /
  get_allocator_backend. Matches the shape of torch.musa.memory.* where
  it doesn't conflict with TF idioms; deliberately omits
  Stream/Event/MemPool per plan §6.3.
- python/device.py: device_count / current_device / get_device_name /
  is_available. Device discovery goes through TF so MUSA_VISIBLE_DEVICES
  and plugin visibility stay consistent with what TF actually schedules.
- python/__init__.py: re-exports the new submodules without forcing _C
  to load at import time (the loader is lazy), keeping the package
  import cheap and usable on CPU-only hosts.
- test/test_python_memory_api.py: shape tests that work without MUSA
  hardware, plus an end-to-end subprocess test that drives TF traffic
  and observes it through the new wrappers.
Adds the optional diagnostic tier called out in plan §5.6 / §6.3:
- python/snapshot.py
- memory_snapshot(device=None) -> dict: composes stats + segments +
driver free/total + parsed TF_MUSA_ALLOC_CONF + VMM probes +
last-OOM message + (when active) sampled history, all from
existing _C entry points. No new native ABI.
- _dump_snapshot(path, device=None): atomic JSON write via
<path>.tmp + os.replace so file-watchers see either the old or
the new file, never a truncated one.
- _record_memory_history(enabled, max_entries=1024, interval_ms=50):
daemon-thread ring-buffer sampler over _device_allocator_stats.
Sampling (vs alloc/free hooks) keeps the native contract flat and
the hot path untouched; the note at §4.2 S5 explicitly allows
"size+stream+event" snapshots without symbolicated stacks for
the first drop.
- python/memory.py: re-exports the three helpers so the plan's public
API lives in one place.
- test/test_memory_snapshot.py: shape checks, JSON round-trip,
ring-buffer trim, monotonic timestamps, and confirmation that the
sampler is off by default.
Introduces the three benchmark scripts called out in plan §5.5 (task
7-a) plus the CI regression backstop (task 7-b):
- benchmark/bench_h2d.py: pure H2D / D2H throughput sweep across a size
  range (4K..16M by default). Uses subprocess isolation so every run
  starts with a cold allocator; reports mean / p50 / p95 / GB/s.
- benchmark/bench_resnet.py: 100-step synthetic training loop with
  ResNet50 when tf.keras is healthy, falling back to a hand-rolled conv
  stack when keras's dtensor import is broken (common on TF 2.6 with a
  stand-alone keras wheel). Hand-rolled SGD avoids pulling in
  tf.optimizers; conv shapes are chosen so MUSA's symmetric-padding
  requirement is met. Reports mean / p50 / p95 step time plus peak
  allocator footprint.
- benchmark/bench_alloc_churn.py: mixed-size allocation storm that
  exercises the caching allocator's split/merge paths. Asserts four
  invariants at the end (no leaks, no spurious OOM, empty_cache releases
  >=95% of reserved, cache_hits > 0) and exits non-zero on any
  violation.
- test/test_bench_alloc_churn.py: wraps the churn bench in a cheap
  pytest (iters=20, batch=8) so CI can run it as part of the regular
  suite and fail fast when the allocator regresses.
First deliverable in the plan's "optional" phase (§3.8, §4.3 C3):
multi-GPU peer-access capability probes and enable control, with a
process-wide cache so repeat queries stay cheap.
- musa_ext/mu/device/peer_access.{h,cc}: leaked-singleton PeerAccess
that wraps musaDeviceCanAccessPeer + musaDeviceEnablePeerAccess.
Saves/restores the caller's current device around the enable call
so the probe is side-effect free; squashes ErrorPeerAccessAlreadyEnabled
to success for idempotency; exposes a Snapshot() view of observed
pairs for diagnostics.
- musa_ext/python/_C.cpp: four new entries (_peer_device_count,
_peer_can_access, _peer_enable_access, _peer_access_snapshot).
- python/device.py: user-facing can_access_peer / enable_peer_access /
peer_access_snapshot. Documents that the plugin's memcpy_dtod path
does NOT yet auto-dispatch peer copies; these entry points give
multi-GPU users a way to enable access manually (or drive their own
musaMemcpyPeerAsync call).
- test/test_peer_access.py: validates self-access, argument handling,
out-of-range behavior, and — when >= 2 MUSA devices are present —
the cached lookup and idempotent enable paths.
- CMakeLists.txt: registers peer_access.cc in MUSA_CORE_SOURCES so
both libmusa_core.so and _C.so see the same singleton, matching the
plan §6.1 "shared state across libraries" architecture.
This commit is intentionally decoupled from the main allocator path;
follow-ups can add peer-aware memcpy dispatch and extend VMM
expandable segments with peer muMemSetAccess lists, both listed in
plan §3.8 but gated on real multi-GPU demand.
`DeviceCachingAllocator::Allocate()` used to release the allocator mutex
between two critical sections on the cache-hit path: the first erased
the block from `free_blocks` and called `MaybeSplitLocked`, and the
second re-acquired the lock to set `block->allocated = true` and emplace
the block into `active_blocks`. In the gap the block was still linked
into its segment's `prev`/`next` chain but looked free
(`allocated == false`) and lived in neither map. A concurrent `Free()`
on an address-adjacent block would then enter `MergeNeighborsLocked`,
see `!block->allocated`, absorb our just-taken block and `delete` it,
leaving a dangling pointer in `free_blocks` and causing the next
cache-hit lookup to return a `Block*` whose memory had been recycled.

The concrete crash signatures on a multi-threaded TF inference run:

    free(): double free detected in tcache 2
    Check failed: IsAligned() ptr = 0x<non-16B-aligned>

Fix: fold the "commit to active" step (`block->allocated = true`,
`active_blocks.emplace`, in_use / peak accounting) into the first
critical section for both the cache-hit and cache-miss paths, so a
concurrent merge can never observe this block as mergeable. The driver
call (`musaMalloc` / VMM reserve) still happens with the lock released,
which was the only reason the second critical section existed in the
first place.

Adds `test_concurrent_allocate_free_does_not_corrupt_blocks`, a
subprocess-isolated stress test that drives four host threads through
differently-shaped matmuls against the same per-device allocator and
asserts the workload runs to completion with `in_use_bytes == 0` at the
end.
Prepare the repo for open-source release:
* README (zh + en) slimmed to a professional overview: condense the
feature list, drop internal design references and verbose debugging
sections, and point memory tuning at the new public env-var doc.
* `docs/` is now user-facing only and contains just
`environment-variables.md`, a pure reference for every `TF_MUSA_*`
knob grouped by subsystem (device allocator, host pinned, H2D
staging, auto-pin, event pool, graph optimizer).
* Internal design / debugging / compatibility docs move under
`internal-docs/`:
- `architecture-and-memory.md` (new): end-to-end architecture and
memory scheduling reference (retained for internal use).
- `memory-optimization.md`: history of the torch_musa-inspired
memory work, with the env-var table aligned to current defaults.
- `DEBUG_GUIDE.md`, `tf-compat-matrix.md`: moved with git history.
- `README.md`: explains the split and how to exclude the
directory from open-source releases.
No code changes; wheel / build layout is unaffected.
`./build.sh debug` used to flip on a musaEvent-based per-kernel timing
scope gated behind the `MUSA_KERNEL_DEBUG` compile flag. The
implementation had drifted out of sync with the rest of the plugin, its
env-var surface (`MUSA_TIMING_KERNEL_LEVEL` / `_NAME` / `_STATS`) was
never finished, and it duplicated device-side timing that the TF
profiler already exposes. Remove the feature outright:
* `musa_ext/utils/logging.h`: delete the ~760 LOC timing scope
  (`KernelTimingScope`, `KernelTimingConfig`,
  `KernelTimingStatsRegistry` and the musaEvent lifecycle helpers) that
  lived under `#ifdef MUSA_KERNEL_DEBUG`. The existing kernel call sites
  (`MUSA_KERNEL_TIMING_GUARD(ctx)`, `MUSA_KERNEL_TRACE_START/END(...)`)
  are preserved as unconditional no-op macros so the 50+ kernel files
  that use them keep compiling without churn, and we retain a clean hook
  point if per-kernel timing is reintroduced later.
* `CMakeLists.txt`: drop the `MUSA_KERNEL_DEBUG` CMake option, the
  `if(MUSA_KERNEL_DEBUG)` flag branch (including the stale
  `-DMUSA_KERNEL_DEBUG` define), and the matching status line.
* `setup.py`: remove `-DMUSA_KERNEL_DEBUG=OFF` from the CMake
  invocation.
* `build.sh`: stop threading `MUSA_KERNEL_DEBUG` through to CMake and
  reword the `debug` build mode to reflect what it actually does now:
  `CMAKE_BUILD_TYPE=Debug`, i.e. unoptimized with `-g -O0` for gdb.
* `README.md` / `README.en.md`: update the Debug-mode row accordingly.
* `internal-docs/DEBUG_GUIDE.md`: drop the kernel-timing and
  memory-coloring sections (the latter was doc-only and never had
  backing code), prune `MUSA_TIMING_KERNEL_*` from the env-var table,
  and redirect the performance-troubleshooting checklist at the TF
  profiler plus the public `docs/environment-variables.md` reference.