Skip to content

Conversation

@pctablet505
Copy link
Collaborator

This pull request refactors how export options are passed to the model.export() method for different formats in Keras. Instead of using generic keyword arguments, format-specific options are now grouped into dedicated dictionaries (saved_model_kwargs, onnx_kwargs, litert_kwargs, openvino_kwargs). This improves clarity, type safety, and maintainability of the export API. The test suite and documentation have been updated to reflect these changes.

API refactor: format-specific export options

  • The model.export() method signature in model.py now accepts dedicated keyword arguments for each export format (saved_model_kwargs, onnx_kwargs, litert_kwargs, openvino_kwargs) instead of generic **kwargs.
  • The docstring and usage examples for model.export() have been updated to show how to use the new format-specific kwargs, including detailed explanations and code samples for each format. [1] [2] [3]

Implementation changes

  • The internal calls to format-specific export functions now unpack the corresponding kwargs dictionary, ensuring only relevant options are passed to each exporter.

Test suite updates

  • All tests in litert_test.py that previously passed optimizations and other options directly to model.export() now use the new litert_kwargs dictionary to pass these options. [1] [2] [3] [4] [5] [6] [7] [8]

pctablet505 and others added 30 commits May 6, 2025 10:25
Corrected indentation in doc string
Fixed issue with passing a single image without batch dimension.
Test case for unbatched inputs
Testcase for checking both unbatched and batched single image inputs.
There was a bug, and it was causing cycle in graph.
removed the use of tree.map_structure
Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from
  check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.

Relates to keras-hub PR keras-team#2257 and addresses flash attention debugging needs.
…sh_attention`

Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to check_layout()
  call to match updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to
  preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False

This resolves compatibility issues with newer JAX versions and improves
debugging experience by surfacing specific technical details about
flash attention compatibility failures.
Simplified the check for `flasth_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
pctablet505 and others added 19 commits November 17, 2025 16:39
Refactored the fallback TFLite conversion method to use a direct tf.function approach instead of a tf.Module wrapper, simplifying the conversion logic. Added a 'verbose' parameter to export_litert and LiteRTExporter for progress messaging. Improved converter kwargs handling to only apply known TFLite settings.
Eliminates the 'verbose' parameter from export_litert and LiteRTExporter, simplifying the API and reducing unnecessary options for export progress messages.
Update LiteRTExporter to always enable resource variables during TFLite conversion, as Keras 3 only supports resource variables. Simplify conversion logic by removing strategy loop and error handling for unsupported conversion paths.
Deleted the _has_dict_inputs method from the LiteRTExporter class in litert.py as it is no longer used. This helps clean up the code and improve maintainability.
Simplifies and enforces stricter validation for converter kwargs in LiteRTExporter. Unknown attributes now raise ValueError instead of being ignored, and the method no longer maintains a list of known attributes, relying on attribute existence checks.
Eliminates the fallback method that used SavedModel as an intermediate step for TFLite conversion. Now, if direct conversion fails, a RuntimeError is raised with a helpful message, simplifying the export logic and error handling.
Replaces tf.keras.layers and tf.keras.Model references with locally imported layers and models from keras.src. This improves consistency and may help with modularity or compatibility within the keras.src namespace.
Improves input signature inference and adapter creation for models with nested input structures (dicts, lists, etc.) in LiteRTExporter. Moves TensorSpec creation logic to export_utils and updates TFLayer to use tree.map_structure for save spec generation. Removes legacy dict-specific input signature inference and centralizes input structure handling for TFLite conversion.
Included the ai-edge-litert package in requirements.txt to support new functionality or dependencies.
Improves analysis of input signatures by unwrapping single-element lists for Functional models and consistently using the correct structure for input handling. Also updates input layer naming to prefer spec.name when available, ensuring more accurate input identification.
Brings in latest upstream changes from keras-team/master including:
- Support PyDataset in Normalization layer adapt methods
- Fix Torch output_padding constraint for ConvTranspose layers
- Improve error messages and validation
- Various bug fixes and improvements from upstream

This keeps the export branch up-to-date with the latest Keras codebase
while preserving all LiteRT export functionality.
Brings in latest upstream changes including:
- OrbaxCheckpoint callback for data-centric saving and restore
- ldexp function in keras.ops
- Fix test failures when nnx is enabled
- PyDataset support in Normalization layer adapt methods
- Torch output_padding constraint fixes for ConvTranspose layers
- OpenVINO improvements and updates
- TerminateOnNaN raise_error option
- TPU testing workflow
- Various bug fixes and improvements

Resolved conflict in module_utils.py by keeping both:
- litert module (for LiteRT export functionality)
- ocp module (for Orbax checkpoint support)
- Changed Model.export() to accept format-specific kwargs dictionaries:
  saved_model_kwargs, onnx_kwargs, litert_kwargs, openvino_kwargs
- Replaced flat **kwargs with explicit named parameters for better clarity
- Updated all LiteRT test cases to use litert_kwargs={} for options
- Updated keras-hub test helper to use litert_kwargs parameter
- Fixed line length violation in docstring (>80 chars)
Conflicts resolved:
- Model.export() API conflicts: Kept grouped kwargs approach (saved_model_kwargs, onnx_kwargs, litert_kwargs, openvino_kwargs)
- LiteRT test conflicts: Kept updated tests using new litert_kwargs API instead of flat kwargs

This merges latest upstream changes while maintaining the new export API design.
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @pctablet505, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the model.export() API in Keras to improve how export options are handled for different model formats. By replacing a generic **kwargs parameter with explicit, format-specific dictionaries like saved_model_kwargs, onnx_kwargs, litert_kwargs, and openvino_kwargs, the API becomes more intuitive, type-safe, and easier to maintain. This change ensures that only relevant options are passed to each exporter, and the updated documentation and test suite reflect these improvements, providing clearer guidance and robust validation for model export functionalities.

Highlights

  • API Refactor: The model.export() method now uses format-specific keyword arguments (saved_model_kwargs, onnx_kwargs, litert_kwargs, openvino_kwargs) instead of generic **kwargs.
  • Improved Clarity and Type Safety: This change enhances the clarity, type safety, and maintainability of the export API by explicitly grouping options for each format.
  • Documentation Updates: The docstring for model.export() has been updated with detailed explanations and code examples for using the new format-specific kwargs.
  • Test Suite Adaptation: All relevant tests in litert_test.py have been modified to pass LiteRT-specific options via the new litert_kwargs dictionary.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a great refactoring of the model.export() method. Grouping format-specific options into dedicated dictionaries like saved_model_kwargs and litert_kwargs significantly improves the API's clarity, safety, and maintainability, moving away from an opaque **kwargs implementation. The updates to the documentation with new examples are also very valuable. I've found one minor issue in a new documentation example that makes it not runnable, and I've provided a suggestion to fix it. Overall, this is an excellent change.

@codecov-commenter
Copy link

codecov-commenter commented Dec 10, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.30%. Comparing base (46813a3) to head (babf7da).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21910   +/-   ##
=======================================
  Coverage   76.30%   76.30%           
=======================================
  Files         580      580           
  Lines       60031    60031           
  Branches     9433     9433           
=======================================
  Hits        45805    45805           
  Misses      11750    11750           
  Partials     2476     2476           
Flag Coverage Δ
keras 76.17% <ø> (ø)
keras-jax 62.12% <ø> (ø)
keras-numpy 57.32% <ø> (+<0.01%) ⬆️
keras-openvino 34.30% <ø> (ø)
keras-torch 63.22% <ø> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@pctablet505 pctablet505 requested a review from fchollet December 10, 2025 05:46
@hertschuh
Copy link
Collaborator

The issue is that this is a breaking change. Basically we can't do that because it will break existing customers of the API.

Additionally, it forces you to pass a dict instead of explicit arguments.

That is the problem that this is solving? That you need arguments for both saved_model and litert at the same time?

Because this change doesn't improve type safety, if fact it does the reverse, it allows you to use dictionary keys that are not valid argument names.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants