
Add option to memory map .ORT model loads #28164

Merged
tianleiwu merged 12 commits into main from user/kevintaha/memMapOrt on May 2, 2026
Conversation

@Kevin-Taha (Contributor) commented Apr 21, 2026

Addressing issue #25524 (MS internal: 60577894)

Today, the closest callers can get to loading a model from a shared resource is to map the model themselves and pass the bytes in with use_ort_model_bytes_directly, which also makes the caller responsible for keeping the mapping valid. These changes introduce use_memory_mapped_ort_model, a session option that uses memory-mapped I/O to load ORT-format models directly inside ONNX Runtime; the mapping in this case is owned by the InferenceSession. The implementation is small and reuses ORT's existing platform-agnostic memory-mapping helpers, and if we later choose to make this the default behavior, it could yield automatic memory savings for multi-process usage.
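As an illustrative sketch only (stdlib Python, not ORT's actual Env helpers; the function name below is hypothetical), the mmap load path conceptually maps the .ort file read-only and hands out a view of its bytes, with the loader owning the mapping's lifetime:

```python
import mmap
import os
import tempfile

def load_model_bytes_mapped(path):
    """Map a file read-only and return (mapping, view of its bytes).

    The caller (standing in for InferenceSession, which in the actual
    change owns the mapping via Env::MappedMemoryPtr) must keep the
    mapping alive for as long as the view is in use.
    """
    with open(path, "rb") as f:
        # length=0 maps the whole file; ACCESS_READ keeps pages shareable
        mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    return mapping, memoryview(mapping)

if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False, suffix=".ort") as tmp:
        tmp.write(b"ORTM" + bytes(64))  # stand-in for real .ort content
        path = tmp.name
    mapping, view = load_model_bytes_mapped(path)
    print(view[:4].tobytes())  # b'ORTM'
    view.release()             # release the buffer before closing the map
    mapping.close()
    os.remove(path)
```

Because the mapping is read-only, multiple processes mapping the same file can share the same physical pages, which is the basis of the multi-process savings discussed below.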

Note about memory implications & sharing model bytes:

In practice, using use_memory_mapped_ort_model alone offers no long-running memory advantage, because ORT ultimately copies the model bytes from the mapped pages into Tensors. Using it together with session.use_ort_model_bytes_for_initializers ensures that initializers point directly at the flatbuffer bytes and avoids the extra copy; this is the expected usage for multi-process sharing of a single model. That raises questions about what the default behavior should be. The changes in this PR are conservative and retain all existing defaults at this time.
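A toy stdlib sketch (not ORT internals; both helper names are hypothetical) of the distinction: the default path copies tensor data out of the mapped pages into private memory, while the bytes-for-initializers path keeps zero-copy references into the mapping:

```python
import mmap
import os
import tempfile

def copied_initializer(view, start, end):
    """Default path: the tensor gets a private copy of the bytes, so the
    mapped pages no longer back the live data (no sharing benefit)."""
    return view[start:end].tobytes()

def direct_initializer(view, start, end):
    """bytes-for-initializers path: a zero-copy slice that still points
    into the mapped flatbuffer bytes."""
    return view[start:end]

if __name__ == "__main__":
    path = tempfile.NamedTemporaryFile(delete=False).name
    with open(path, "wb") as f:
        f.write(b"\x00" * 16 + b"WEIGHTS!" + b"\x00" * 8)
    with open(path, "rb") as f:
        m = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    view = memoryview(m)
    copied = copied_initializer(view, 16, 24)
    direct = direct_initializer(view, 16, 24)
    print(copied)            # b'WEIGHTS!'
    print(direct.obj is m)   # True: still backed by the mapping
    direct.release()
    view.release()
    m.close()
    os.remove(path)
```

The zero-copy slice is only valid while the mapping stays alive, which is why the mapping must be retained by the session when initializers reference it.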

Changes

  • onnxruntime_session_options_config_keys.h — New session.use_memory_mapped_ort_model config key
  • inference_session.h — Added Env::MappedMemoryPtr member to hold the file mapping; updated existing comments to document the mmap path
  • inference_session.cc — New LoadOrtModelBytesMapped() static function; updated LoadOrtModel(PathString) to check the config and use mmap; updated Initialize() cleanup to release the mapping; updated the comment on initializer gating to note the mmap case
  • ort_model_only_test.cc — Two new tests: LoadOrtFormatModelMemoryMapped and LoadOrtFormatModelMemoryMappedWithInitializersFromMap
  • Also checking in a benchmarking tool, benchmark_mmap_ort.py, just for preservation, but this is optional and can be omitted.
  • Added a flag to the perf tests used by the benchmark to hold onto the session for a specified amount of time - useful for measuring memory sharing changes. We can revert these and exclude the benchmark if they are not desired for check-in.

Benchmark Examples

Note that the benchmark was largely written by GitHub Copilot and may not be perfect, but I've validated some of its results.

Single-Proc

Here is a sample result from a single-process benchmark using resnet50 (converted to ORT format). Note that these measure peaks during session construction, not end states, and the measurements may be imperfect.
python tools/python/benchmark_mmap_ort.py --perf-test build\Windows\Release\Release\onnxruntime_perf_test.exe --model resnet50.ort --iterations 15

| Configuration | Session Creation (ms) | Peak Private Commit (MB) | Peak Working Set (MB) | Session vs baseline | Private vs baseline |
|---|---|---|---|---|---|
| .ort standard load (baseline) | 193.13 | 222.9 | 235.9 | | |
| .ort memory-mapped load | 120.95 | 125.7 | 236.1 | -37.4% | -43.6% |
| .ort mmap + direct initializers | 14.87 | 109.6 | 120.6 | -92.3% | -50.8% |

Multi-Proc

The multi-process benchmark shows that total memory savings for shared models are only realized when combined with use_ort_model_bytes_for_initializers.

| Configuration (4 processes) | Total Private (MB) | Total Working Set (MB) | Private vs baseline |
|---|---|---|---|
| .ort standard load (baseline) | 462.6 | 519.0 | |
| .ort memory-mapped load | 462.1 | 518.5 | -0.1% |
| .ort mmap + direct initializers | 98.2 | 187.8 | -78.8% |
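For reference, the "vs baseline" columns in both tables are plain percent changes against the baseline row. A quick sketch of the arithmetic, with the values transcribed from the tables:

```python
def pct_vs_baseline(value, baseline):
    """Percent change vs the baseline row, rounded to one decimal as in
    the tables above."""
    return round((value - baseline) / baseline * 100, 1)

# Single-process: session creation time and peak private commit
print(pct_vs_baseline(120.95, 193.13))  # -37.4  (mmap load, session)
print(pct_vs_baseline(14.87, 193.13))   # -92.3  (mmap + direct init, session)
print(pct_vs_baseline(125.7, 222.9))    # -43.6  (mmap load, private)
print(pct_vs_baseline(109.6, 222.9))    # -50.8  (mmap + direct init, private)

# Multi-process (4 processes): total private bytes
print(pct_vs_baseline(462.1, 462.6))    # -0.1
print(pct_vs_baseline(98.2, 462.6))     # -78.8
```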

@github-actions bot left a comment:

You can commit the suggested changes from lintrunner.

Kevin Taha and others added 4 commits April 23, 2026 13:37
…ap benchmarking

- Clean up benchmark_mmap_ort.py: remove unused code, simplify multi-process
  approach to use native perf_test processes instead of Python wrappers
- Add --hold_ms_after_session_creation flag to onnxruntime_perf_test to keep
  sessions alive for multi-process memory measurement
- Print SESSION_READY marker when holding so benchmark script knows when to sample

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@Kevin-Taha Kevin-Taha force-pushed the user/kevintaha/memMapOrt branch from 8c170c9 to 50bc19e Compare April 23, 2026 20:37
@Kevin-Taha Kevin-Taha marked this pull request as ready for review April 27, 2026 20:13
@yuslepukhin (Member) commented Apr 28, 2026

Thanks for a great PR. I will address a few things here and will ask @skottmckay for review.

@yuslepukhin (Member) commented:
The comment about copying the initializers is not entirely correct. We strive to create tensors on top of the flatbuffer memory.

Copilot AI left a comment
Pull request overview

This PR adds an opt-in session configuration to load .ort (ORT-format) models via memory-mapped I/O, so the InferenceSession owns the mapping and can optionally keep initializer tensors referencing the mapped flatbuffer bytes.

Changes:

  • Introduces session.use_memory_mapped_ort_model and wires it into ORT-format model loading.
  • Keeps the memory mapping alive for the session when using direct initializers, and releases it after Initialize() otherwise.
  • Adds unit tests for memory-mapped .ort loading and extends onnxruntime_perf_test plus a Python benchmark helper for multi-process memory measurement.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| tools/python/benchmark_mmap_ort.py | Adds a developer benchmark script to compare standard vs mmap ORT model loads (single- and multi-process). |
| onnxruntime/test/perftest/test_configuration.h | Adds hold_ms_after_session_creation to keep perf-test processes alive for memory measurements. |
| onnxruntime/test/perftest/main.cc | Implements SESSION_READY signaling + sleep when -n is used with the new hold flag. |
| onnxruntime/test/perftest/command_args_parser.cc | Adds --hold_ms_after_session_creation flag wiring and help text. |
| onnxruntime/test/framework/ort_model_only_test.cc | Adds coverage for .ort loading via mmap (with/without initializer-bytes usage). |
| onnxruntime/core/session/inference_session.h | Stores an Env::MappedMemoryPtr in InferenceSession for mmap-backed ORT-format model bytes. |
| onnxruntime/core/session/inference_session.cc | Adds mmap load path for ORT-format models and releases mapping during Initialize() cleanup when applicable. |
| include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h | Defines the new session.use_memory_mapped_ort_model config key and updates related docs. |


@tianleiwu left a comment

Review Summary

Clean, well-scoped feature. The implementation cleverly leverages existing infrastructure: ort_format_model_bytes_data_holder_.empty() naturally enables the initializer-from-bytes path for mmap without additional flags. The cleanup in Initialize() correctly resets ort_format_model_mapped_memory_ when initializers don't use the bytes. The reinterpret_cast is correct for char* to uint8_t* (agreeing with yuslepukhin's comment).

A couple of suggestions below. The doc comment style (// vs ///) and benchmark timeout concerns from the earlier automated review are also worth addressing.

chwarr previously requested changes Apr 30, 2026
@yuslepukhin yuslepukhin requested review from chwarr and Copilot April 30, 2026 21:15
Copilot AI left a comment

Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.



yuslepukhin and others added 2 commits April 30, 2026 14:54
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@tianleiwu left a comment

Thanks for the updates. I found one remaining measurement issue in the helper script: on POSIX, the multi-process private metric falls back to RSS, which double-counts shared mmap pages and can hide the benefit this benchmark is meant to measure. I left that inline.

I did not open a duplicate comment for the existing file sizing/opening thread; I replied there with what still seems unresolved after the latest changes.
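The double-counting concern can be shown with a toy accounting model (pure arithmetic, not psutil output; the function name is hypothetical): each process's RSS counts the shared mapping in full, while USS (unique set size) counts only private pages, so only USS reveals the multi-process saving.

```python
def total_rss_uss(n_procs, private_mb, shared_mb):
    """Model per-metric totals for n processes that share one mapping.

    RSS counts shared resident pages once per process; USS counts only
    pages private to each process.
    """
    total_rss = n_procs * (private_mb + shared_mb)  # shared counted n times
    total_uss = n_procs * private_mb                # shared pages excluded
    return total_rss, total_uss

# 4 processes, 15 MB private each, sharing one 100 MB mapped model:
# RSS totals 460 MB and hides the sharing; USS totals 60 MB and shows it.
print(total_rss_uss(4, 15, 100))  # (460, 60)
```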

@chwarr (Member) left a comment

I won't have time for an in-depth review for a few weeks. Please don't block waiting for me.

@yuslepukhin yuslepukhin requested a review from tianleiwu May 1, 2026 17:06
@tianleiwu left a comment

I found one low-level bounds-check issue in the new file-mapping validation. The mmap loading flow and tests otherwise look reasonable to me, but the requested-end arithmetic should be made overflow-safe before merge.

@yuslepukhin yuslepukhin requested a review from tianleiwu May 2, 2026 00:15
@tianleiwu left a comment

All previously raised concerns have been addressed in the latest commits:

  • Overflow-safe bounds check: Both POSIX and Windows MapFileIntoMemory now use SafeInt<size_t>(offset) + length instead of raw arithmetic, preventing silent wrap-around before the file-size comparison.
  • POSIX memory measurement: The benchmark script now uses memory_full_info().uss on POSIX for accurate private memory accounting, with RSS fallback properly labeled.
  • Robustness: Timeout handling, try/finally cleanup, and errors="replace" for decode safety.

The implementation is clean and well-scoped. Session owns the mapping via Env::MappedMemoryPtr (RAII), lifecycle is correct (freed in Initialize() unless initializers reference the mapped bytes), and the test coverage is adequate.
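The overflow hazard in the bounds check can be sketched by emulating 64-bit size_t wrap-around in Python (SafeInt instead detects the overflow itself and fails; the safe variant below is an equivalent pre-check, and both function names are hypothetical):

```python
MASK64 = (1 << 64) - 1  # emulate 64-bit unsigned (size_t) arithmetic

def unsafe_check(offset, length, file_size):
    """Raw arithmetic as it behaves with 64-bit size_t: offset + length
    can silently wrap around before the comparison runs."""
    return ((offset + length) & MASK64) <= file_size

def safe_check(offset, length, file_size):
    """Overflow-safe form: only subtract/compare in ranges that cannot
    wrap, rejecting bad requests before any risky arithmetic."""
    return length <= file_size and offset <= file_size - length

file_size = 4096
offset = MASK64   # 2**64 - 1
length = 2        # offset + length wraps to 1
print(unsafe_check(offset, length, file_size))  # True  (bogus pass)
print(safe_check(offset, length, file_size))    # False (correctly rejected)
```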

@tianleiwu tianleiwu dismissed chwarr’s stale review May 2, 2026 06:08

see chwarr's comment above: "I won't have time for an in-depth review for a few weeks. Please don't block waiting for me."

@tianleiwu tianleiwu merged commit 9d1492a into main May 2, 2026
85 of 90 checks passed
@tianleiwu tianleiwu deleted the user/kevintaha/memMapOrt branch May 2, 2026 06:08
sanaa-hamel-microsoft added a commit that referenced this pull request May 4, 2026
This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|-------------|
| 9d1492a | #28164 | Add option to memory map .ORT model loads |

Co-authored-by: Kevin Taha <tahakevin@gmail.com>
Co-authored-by: Kevin Taha <kevintaha@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Dmitri Smirnov <dmitrism@microsoft.com>
Co-authored-by: Dmitri Smirnov <yuslepukhin@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

8 participants