feat(pt): add compression support for se_e3_tebd #4992
base: devel
Conversation
Pull Request Overview
This PR adds compression support for the se_e3_tebd (SE_T_TEBD) descriptor type. Compression tabulates the embedding network to improve computational efficiency during inference.
- Adds "T_TEBD" descriptor type support to the tabulation system
- Refactors variable names from `xx`/`vv`/`tt` to more descriptive names like `mesh`/`value`/`stride`
- Implements compression methods for the SE_T_TEBD descriptor block
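The tabulation idea described above can be sketched as follows. This is a minimal illustration, not the actual deepmd-kit code: the helper name and coefficient values are hypothetical. Instead of running the embedding net at inference time, each input is located in a precomputed table and a fifth-order polynomial for that segment is evaluated.

```python
def eval_spline_segment(coeffs, x):
    """Evaluate a0 + a1*x + a2*x^2 + ... + a5*x^5 in Horner form,
    as a tabulated spline segment would be evaluated at inference."""
    a0, a1, a2, a3, a4, a5 = coeffs
    return a0 + x * (a1 + x * (a2 + x * (a3 + x * (a4 + x * a5))))

# Hypothetical segment whose tabulated polynomial is y = 1 + 2x:
coeffs = [1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
print(eval_spline_segment(coeffs, 0.5))  # 2.0
```

A real table holds one such coefficient row per spline segment and per output channel, so a forward pass becomes an index lookup plus a quintic evaluation.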
Reviewed Changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
examples/water/.gitignore | Adds *.hdf5 to ignored files |
deepmd/utils/tabulate.py | Adds T_TEBD support and refactors variable naming throughout tabulation logic |
deepmd/tf/utils/tabulate.py | Updates parameter names from xx to mesh in TensorFlow tabulation |
deepmd/pt/utils/tabulate.py | Adds T_TEBD descriptor support and updates PyTorch tabulation implementation |
deepmd/pt/model/descriptor/se_t_tebd.py | Implements compression functionality for SE_T_TEBD descriptor |
> Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough
Adds SE_T_TEBD tabulation and compression: new compression state and enable_compression APIs in SeTTebd descriptor classes, shared spline/table generation and T_TEBD support in tabulate utilities, new CPU/GPU tabulation kernels and autograd bindings, multi-device dispatch, and unit tests for the TEBD tabulate op.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python
    participant Desc as SeTTebd descriptor
    participant Tab as DPTabulate / tabulate utils
    participant Lib as C++ tabulate (CPU/GPU)
    Note over Py,Desc: Compression enablement flow
    Py->>Desc: enable_compression(params)
    Desc->>Tab: build shared table (table, table_info)
    Tab->>Lib: tabulate_fusion_se_t_tebd(table, table_info, em_x, em, last_layer_size)
    Lib-->>Tab: descriptor tensor (forward) / dy_dem_x (backward)
    Tab-->>Desc: tabulated embeddings / table metadata
    Desc-->>Py: forward uses compressed (tabulated) embedding when compress==True
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (1 passed)

📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
🚧 Files skipped from review as they are similar to previous changes (2)
🧰 Additional context used
📓 Path-based instructions (1): **/*.py — 📄 CodeRabbit inference engine (AGENTS.md)
🧬 Code graph analysis (5)
- deepmd/pt/utils/tabulate.py (2)
- source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
- source/op/pt/tabulate_multi_device.cc (2)
- deepmd/pt/model/descriptor/se_t_tebd.py (6)
- source/lib/src/gpu/tabulate.cu (1)
🪛 Ruff (0.13.1): deepmd/pt/model/descriptor/se_t_tebd.py — 552: TRY003 (avoid specifying long messages outside the exception class); 572-574: TRY003; 577: TRY003
⏰ Context from checks skipped due to timeout of 90000ms (29). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms).
🔇 Additional comments (31)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/utils/tabulate.py (1)
Lines 314-318: Off-by-one in stride window selection (T/T_TEBD) corrupts spline segments.
Using +1 shifts the fine-stride window by one interval; the first fine interval remains coarse, breaking coefficients.
Apply:

```diff
- start_index = int((lower - extrapolate * lower) / stride1) + 1
+ start_index = int((lower - extrapolate * lower) / stride1)
  end_index = start_index + int((upper - lower) / stride0)
  stride[start_index:end_index, :] = stride0
```
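To see why the extra "+ 1" matters, here is a minimal standalone sketch with made-up grid parameters (not the values deepmd-kit uses; they are chosen as exact binary fractions so the integer divisions are clean):

```python
# Hypothetical grid: the extrapolated region [extrapolate * lower, lower)
# is gridded with coarse stride1; the trained region [lower, upper)
# should use fine stride0.
lower, upper = -1.0, 1.0
extrapolate = 5.0
stride0, stride1 = 0.125, 0.5  # exact binary fractions for clean division

# Correct start of the fine-stride window: exactly where the coarse region
# ends, counted in stride1 units from extrapolate * lower.
start_index = int((lower - extrapolate * lower) / stride1)  # 8
end_index = start_index + int((upper - lower) / stride0)    # 24
# An extra "+ 1" would shift the window to [9, 25): interval 8, the first
# fine interval (starting at x = lower), would keep the coarse stride and
# its spline coefficients would be computed with the wrong step.
print(start_index, end_index)
```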
🧹 Nitpick comments (4)
deepmd/pt/model/descriptor/se_t_tebd.py (3)
Lines 527-603: Enable-compression flow for TEBD — mostly OK; tighten message and comment.
- Logic and wiring to DPTabulate look correct.
- Improve the mode check error to reflect the actual value; update the comment to TEBD.

```diff
- if self.tebd_input_mode != "strip":
-     raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
+ if self.tebd_input_mode != "strip":
+     raise RuntimeError(f"Compression requires tebd_input_mode='strip' (got '{self.tebd_input_mode}')")
  ...
- # Scale the stride values for SE_T descriptor
+ # Scale the stride values for TEBD descriptor
```

As per coding guidelines.
Lines 1021-1060: Signature/type hints mismatch for table_config.
table_config is used as an indexable sequence (list), but annotated as dict. Align the type hints and docstring.

```diff
  def enable_compression(
      self,
      table_data: dict,
-     table_config: dict,
+     table_config: list[float] | tuple[float, float, float, int],
      lower: dict,
      upper: dict,
  ) -> None:
@@
-     table_config : dict
-         Configuration for table compression
+     table_config : list[float] | tuple[float, float, float, int]
+         [extrapolate, stride0, stride1, check_frequency]
```

As per coding guidelines.
Lines 552-577: Shorten/standardize exception messages (ruff TRY003).
A couple of the raised messages are verbose. Keep them concise to satisfy TRY003.

```diff
- assert not self.se_ttebd.resnet_dt, (
-     "Model compression error: descriptor resnet_dt must be false!"
- )
+ assert not self.se_ttebd.resnet_dt, "resnet_dt must be False for compression"
@@
- raise RuntimeError(
-     "Empty embedding-nets are not supported in model compression!"
- )
+ raise RuntimeError("Empty embedding nets not supported for compression")
```

Run `ruff check .` to confirm.
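Another common way to satisfy TRY003 is to move the message into a dedicated exception class so the raise site stays short. A generic sketch, not deepmd-kit code (the class name and message are hypothetical):

```python
class CompressionModeError(RuntimeError):
    """Raised when a descriptor cannot be compressed in its current mode."""

    def __init__(self, mode: str) -> None:
        # The long, formatted message lives in the class, not at the raise site.
        super().__init__(f"compression requires tebd_input_mode='strip' (got '{mode}')")


# The raise site is now a one-liner that passes TRY003:
try:
    raise CompressionModeError("concat")
except CompressionModeError as err:
    assert "concat" in str(err)
```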
deepmd/utils/tabulate.py (1)
Lines 293-304: _generate_spline_table type hints incorrect.
stride0/stride1 and extrapolate are floats; the current annotations are misleading.

```diff
- stride0: int,
- stride1: int,
- extrapolate: bool,
+ stride0: float,
+ stride1: float,
+ extrapolate: float,
```

As per coding guidelines.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- .gitignore (1 hunks)
- deepmd/pt/model/descriptor/se_t_tebd.py (6 hunks)
- deepmd/pt/utils/tabulate.py (9 hunks)
- deepmd/tf/utils/tabulate.py (2 hunks)
- deepmd/utils/tabulate.py (12 hunks)
- examples/water/.gitignore (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:
deepmd/pt/utils/tabulate.py
deepmd/tf/utils/tabulate.py
deepmd/utils/tabulate.py
deepmd/pt/model/descriptor/se_t_tebd.py
🧬 Code graph analysis (4)
deepmd/pt/utils/tabulate.py (2)
- deepmd/utils/tabulate.py (1): _make_data (391-408)
- deepmd/tf/utils/tabulate.py (3): _make_data (319-466), _layer_1 (471-473), _layer_0 (468-469)

deepmd/tf/utils/tabulate.py (2)
- deepmd/pt/utils/tabulate.py (5): _make_data (114-281), _layer_0 (283-286), unaggregated_dy_dx_s (502-521), unaggregated_dy2_dx_s (524-550), _layer_1 (288-292)
- deepmd/tf/utils/compress.py (2): _layer_0 (59-60), _layer_1 (63-65)

deepmd/utils/tabulate.py (2)
- deepmd/pt/utils/tabulate.py (1): _make_data (114-281)
- deepmd/tf/utils/tabulate.py (1): _make_data (319-466)

deepmd/pt/model/descriptor/se_t_tebd.py (4)
- deepmd/pt/utils/tabulate.py (1): DPTabulate (30-441)
- deepmd/pt/model/descriptor/se_t.py (3): enable_compression (284-330), enable_compression (743-770), serialize (412-438)
- deepmd/pt/utils/utils.py (1): ActivationFn (175-220)
- deepmd/pt/model/descriptor/se_atten.py (4): enable_compression (427-448), serialize (780-804), serialize (892-914), serialize (1082-1105)
🪛 Ruff (0.13.1)
deepmd/utils/tabulate.py
320-320: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (7)
deepmd/tf/utils/tabulate.py (1)
Lines 318-466: Mesh rename keeps the TensorFlow path aligned.
The reshaping/derivative flow remains identical while matching the new mesh terminology, so the TensorFlow tabulator stays in sync with the PT backend. Looks good.
deepmd/pt/utils/tabulate.py (4)
Lines 309-311: Add TEBD descriptor type mapping — OK. Correctly recognizes DescrptSeTTebd as "T_TEBD".

Lines 325-326: Layer-size handling for T_TEBD — OK. Treating "T_TEBD" like "Atten" for layer sizing aligns with the shared embedding net design.

Lines 394-401: Shared network variables for T_TEBD — OK. Fetching the single shared embedding network per layer matches the TEBD design.

Lines 154-155: Device inconsistency for torch.ones (duplicate of prior feedback). Use env.DEVICE for consistency with the rest of the tensor placements.

```diff
- ) + torch.ones((1, 1), dtype=yy.dtype, device=yy.device)
+ ) + torch.ones((1, 1), dtype=yy.dtype, device=env.DEVICE)
  ...
- ) + torch.ones((1, 2), dtype=yy.dtype, device=yy.device)
+ ) + torch.ones((1, 2), dtype=yy.dtype, device=env.DEVICE)
```

Also applies to lines 173-174.
deepmd/utils/tabulate.py (2)
Lines 203-247: Shared-mesh tabulation for T_TEBD — OK. Global range + single mesh for the shared geometric net is correct; the nspline formula matches the constructed mesh.
Please run `ruff check .` and a quick smoke build of tables for a small synthetic range to ensure nspline equals len(mesh) - 1.

Lines 486-505: Env range update for T_TEBD — OK. Handling T_TEBD with the (cos theta)^2 bounds matches the T path.
- please remove unnecessary changes, such as .gitignore
- please remove the rename in this PR; you should make it a separate PR.
- please add a UT for your modification, such as source/tests/pt/test_tabulate_fusion_se_atten.py.
Actionable comments posted: 0
🧹 Nitpick comments (4)
deepmd/pt/utils/tabulate.py (1)
Lines 278-281: Improved variable naming for clarity.
The rename from `vv` to `value` makes the final output variable more descriptive, improving readability without changing functionality.

deepmd/utils/tabulate.py (3)
Lines 107-109: LGTM! Refactored to use unified spline table generation.
The renaming from `_build_lower` to `_generate_spline_table` with updated parameters improves code clarity and provides a consistent interface for spline table generation across different descriptor types.
Lines 291-302: LGTM! Improved function signature and documentation.
The renamed function with updated parameter names (`xx` → `mesh`, clearer parameter names) and better documentation improves maintainability. The signature change from the old `_build_lower` to `_generate_spline_table` makes the purpose more explicit.
Lines 504-512: LGTM! Cleaner spline switch function implementation.
The updated parameter names (`xx` → `x`) and simplified variable names improve readability. The function logic remains correct and maintains the same mathematical behavior.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/utils/tabulate.py (8 hunks)
- deepmd/utils/tabulate.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:
deepmd/pt/utils/tabulate.py
deepmd/utils/tabulate.py
🧬 Code graph analysis (1)
deepmd/pt/utils/tabulate.py (2)
- deepmd/tf/utils/tabulate.py (1): _layer_1 (471-473)
- deepmd/pt/model/descriptor/se_t_tebd.py (1): DescrptSeTTebd (78-602)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (17)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
🔇 Additional comments (15)
deepmd/pt/utils/tabulate.py (8)
Line 69: LGTM! Added T_TEBD to supported descriptor types.
The addition of "T_TEBD" to the supported descriptor types correctly enables tabulation support for the se_e3_tebd descriptor, which was previously not supported for model compression according to the documentation.

Line 154: Device consistency maintained correctly.

Lines 163-180: LGTM! Correct residual handling for TEBD-style descriptors.
The changes properly handle residual connections in the first layer for TEBD descriptors, matching the pattern used elsewhere in the codebase for similar descriptor types. The device-aware tensor creation is correctly implemented.

Line 173: Device consistency maintained correctly.

Lines 234-247: LGTM! Correct residual handling for TEBD-style descriptors in deeper layers.
The changes extend the residual connection handling to deeper layers, maintaining consistency with the first-layer implementation. The variable naming with `residual` improves code clarity.

Lines 309-311: LGTM! Added TEBD descriptor type recognition.
The addition correctly maps `DescrptSeTTebd` to the "T_TEBD" type, enabling proper classification for the new descriptor type.

Lines 325-326: LGTM! Correct layer sizing for TEBD descriptors.
Treating "T_TEBD" similarly to "Atten" for layer sizing is appropriate since both use shared network architectures, as evidenced by the network variable handling.

Lines 394-400: LGTM! Shared embedding network for TEBD descriptors.
The implementation correctly handles the shared embedding network approach for T_TEBD descriptors, where a single network is used for all type pairs. This aligns with the architectural design described in the relevant code snippets.
deepmd/utils/tabulate.py (7)
Lines 148-150: LGTM! Consistent use of unified spline generation.
The function call correctly uses the new `_generate_spline_table` interface, maintaining the same functionality with improved naming.

Lines 188-199: LGTM! Consistent spline generation for the T descriptor.
The updated call maintains the same functionality while using the new unified interface.

Lines 200-245: LGTM! Well-implemented T_TEBD descriptor support.
The T_TEBD implementation correctly:
- Calculates global ranges across all types for the shared network
- Creates a unified input grid based on global bounds
- Generates spline coefficients once for the shared geometric network
- Uses the appropriate naming convention ("filter_net")
This approach aligns with the shared network architecture of TEBD descriptors and follows the established patterns for other descriptor types.

Lines 280-282: LGTM! Consistent spline generation for the R descriptor.
The updated call maintains functionality while using the new unified interface.

Lines 309-316: Critical stride indexing issue needs attention.
The existing review comment identifies a critical off-by-one error in the stride window selection that affects both T and T_TEBD descriptor types. This issue corrupts spline segments and should be addressed.

Line 436: LGTM! Correct table size handling for TEBD.
Including "T_TEBD" with "Atten" and "AEbdV2" for table size calculation is appropriate since they all use shared network architectures requiring only one table.

Line 490: LGTM! Correct environment matrix range handling for TEBD.
Including "T_TEBD" with "T" for environment matrix range calculation is appropriate since they handle similar geometric (angular) information and use the same mathematical formulation for computing bounds.
Actionable comments posted: 3
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/se_t_tebd.py (1)
Lines 527-603: Check the scaled stride values for SE_T_TEBD compression.
The stride values are scaled by 10x (lines 587-588) to match SE_T behavior. However, the comment on line 586 says "Scale the stride values for SE_T descriptor" rather than SE_T_TEBD. Consider updating the comment for clarity.
Additionally, the runtime checks for `tebd_input_mode` (line 576) and other validations look good.
Apply this diff to clarify the comment:

```diff
- # Scale the stride values for SE_T descriptor
+ # Scale the stride values to match SE_T behavior for TEBD compression
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
deepmd/pt/model/descriptor/se_t_tebd.py
(10 hunks)source/lib/include/tabulate.h
(2 hunks)source/lib/src/gpu/tabulate.cu
(3 hunks)source/lib/src/tabulate.cc
(2 hunks)source/op/pt/tabulate_multi_device.cc
(4 hunks)source/tests/pt/test_tabulate_fusion_se_t_tebd.py
(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:
source/tests/pt/test_tabulate_fusion_se_t_tebd.py
deepmd/pt/model/descriptor/se_t_tebd.py
🧬 Code graph analysis (6)
source/lib/include/tabulate.h (2)
- source/lib/src/gpu/tabulate.cu (25): void (39-48), void (165-252), void (255-362), void (365-471), void (474-514), void (517-578), void (581-631), void (634-672), void (675-718), void (721-761), void (764-795), void (798-840), void (843-882), tabulate_fusion_se_t_tebd_gpu (1058-1078, 1058-1066, 1393-1401, 1403-1411), tabulate_fusion_se_t_tebd_grad_gpu (1081-1104, 1081-1090, 1413-1423, 1425-1435), tabulate_fusion_se_t_tebd_grad_grad_gpu (1107-1132, 1107-1116, 1437-1447, 1449-1459)
- source/lib/src/tabulate.cc (12): tabulate_fusion_se_t_tebd_cpu (545-590, 545-553, 963-972, 973-982), tabulate_fusion_se_t_tebd_grad_cpu (593-641, 593-602, 983-993, 994-1004), tabulate_fusion_se_t_tebd_grad_grad_cpu (644-692, 644-654, 1005-1015, 1016-1026)

source/lib/src/tabulate.cc (1)
- source/lib/src/gpu/tabulate.cu (1): locate_xx_se_t (76-83)

source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
- source/tests/consistent/common.py (1): parameterized (580-640)
- source/op/pt/tabulate_multi_device.cc (2): tabulate_fusion_se_t (1193-1201, 1193-1198)

deepmd/pt/model/descriptor/se_t_tebd.py (6)
- source/lib/include/tabulate.h (1): deepmd (4-293)
- deepmd/pt/utils/tabulate.py (1): DPTabulate (30-441)
- deepmd/pt/model/descriptor/se_t.py (3): enable_compression (284-330), enable_compression (743-770), serialize (412-438)
- deepmd/pt/utils/utils.py (1): ActivationFn (175-220)
- deepmd/utils/tabulate.py (1): build (70-289)
- source/op/pt/tabulate_multi_device.cc (2): tabulate_fusion_se_t_tebd (1203-1211, 1203-1208)

source/op/pt/tabulate_multi_device.cc (2)
- source/lib/src/gpu/tabulate.cu (12): tabulate_fusion_se_t_tebd_gpu (1058-1078, 1058-1066, 1393-1401, 1403-1411), tabulate_fusion_se_t_tebd_grad_gpu (1081-1104, 1081-1090, 1413-1423, 1425-1435), tabulate_fusion_se_t_tebd_grad_grad_gpu (1107-1132, 1107-1116, 1437-1447, 1449-1459)
- source/lib/src/tabulate.cc (12): tabulate_fusion_se_t_tebd_cpu (545-590, 545-553, 963-972, 973-982), tabulate_fusion_se_t_tebd_grad_cpu (593-641, 593-602, 983-993, 994-1004), tabulate_fusion_se_t_tebd_grad_grad_cpu (644-692, 644-654, 1005-1015, 1016-1026)

source/lib/src/gpu/tabulate.cu (1)
- source/lib/src/tabulate.cc (2): locate_xx_se_t (45-73, 45-52)
🪛 Ruff (0.13.1)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (18)
deepmd/pt/model/descriptor/se_t_tebd.py (5)
Line 191: Initialize `compress` state variable.
The implementation correctly initializes the compression state flag to `False`.

Lines 696-702: LGTM! Compression metadata storage initialized correctly.
The `compress_info` and `compress_data` ParameterLists are properly initialized for storing compression tables and configuration.

Lines 954-970: TEBD compression uses the correct tabulation operation.
The tabulated TEBD path correctly uses `torch.ops.deepmd.tabulate_fusion_se_t_tebd` with proper tensor reshaping and dimension handling. The compressed embedding computation preserves the full neighbor structure as expected.

Lines 1021-1023: Proper combination of geometric and type embeddings.
The combination formula `gg = gg_s * gg_t + gg_s` correctly implements the `gg_s * (1 + gg_t)` pattern for merging geometric and type embeddings.
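The algebraic identity behind this fused form can be checked in isolation. A minimal sketch with made-up embedding values (not the descriptor's real tensors or shapes):

```python
gg_s = [0.5, -1.0, 2.0]   # stand-in geometric embedding values
gg_t = [0.1, 0.3, -0.7]   # stand-in type embedding values

# Form used in the forward pass:
fused = [s * t + s for s, t in zip(gg_s, gg_t)]
# Equivalent "modulate gg_s by (1 + gg_t)" view:
modulated = [s * (1.0 + t) for s, t in zip(gg_s, gg_t)]

assert all(abs(a - b) < 1e-12 for a, b in zip(fused, modulated))
```

The fused form avoids materializing the intermediate `1 + gg_t` tensor while producing the same result up to floating-point rounding.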
Lines 1041-1080: Block-level compression setup is correct.
The `enable_compression` method for the block properly configures the compression metadata using the shared geometric embedding network key "filter_net" and updates the compression state.

source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
Lines 18-20: Test class properly parameterized and conditionally skipped.
The test class correctly uses the `@parameterized` decorator for multiple dtypes and skips when PyTorch customized ops are unavailable.

Lines 201-270: Incorrect shape concern — the test uses the 2D op, so the expected (4, 4) shape is correct.
The test calls torch.ops.deepmd.tabulate_fusion_se_t (test_forward in source/tests/pt/test_tabulate_fusion_se_t_tebd.py), and the C++ wrapper for tabulate_fusion_se_t allocates the descriptor as torch::empty({em_tensor.size(0), last_layer_size}) (see source/op/pt/tabulate_multi_device.cc:953-963). The TEBD-specific wrapper tabulate_fusion_se_t_tebd does allocate a 4D tensor (see source/op/pt/tabulate_multi_device.cc:1119-1125), but this test intentionally invokes the 2D variant; expected_descriptor_tensor.reshape(4, 4) matches the actual op output.
Likely an incorrect or invalid review comment.
source/lib/src/tabulate.cc (4)
Lines 544-590: TEBD forward implementation preserves the full neighbor structure.
The implementation correctly preserves the nt_i x nt_j x ng structure for SE_T_TEBD, which differs from SE_T's reduction pattern. The polynomial evaluation and output indexing are correct.

Lines 592-641: TEBD gradient computation correctly accumulates over last_layer_size.
The gradient implementation properly accumulates gradients across all last_layer_size dimensions and stores the result with the correct index pattern.

Lines 643-692: TEBD grad-grad implementation correctly applies the chain rule.
The second-order gradient computation properly multiplies the incoming gradient with the derivative of the polynomial.

Lines 963-1026: Template instantiations properly added for TEBD functions.
All six TEBD functions (forward, grad, grad_grad) are correctly instantiated for both float and double types, matching the pattern of existing tabulation functions.

source/lib/include/tabulate.h (2)
Lines 114-147: CPU function declarations follow existing patterns.
The three new CPU function declarations for TEBD tabulation follow the established pattern and parameter ordering conventions.

Lines 258-291: GPU function declarations properly guarded and consistent.
The GPU function declarations are correctly placed within the CUDA/ROCM preprocessor guards and maintain consistency with the CPU declarations.
source/op/pt/tabulate_multi_device.cc (5)
Lines 338-385: Forward function implementation looks good.
The `TabulateFusionSeTTebdForward` function correctly validates input dimensions, handles device dispatch, and calls the appropriate CPU/GPU kernels.

Lines 387-431: Gradient forward dimension check is correct.
At line 396, the function checks that `dy_tensor.dim() != 4`. For SE_T_TEBD with shape [nloc, nnei_i, nnei_j, last_layer_size], the gradient tensor dy should indeed match the descriptor tensor's 4D shape, so the check is appropriate.

Lines 1086-1168: Autograd operator correctly implements forward and backward passes.
The `TabulateFusionSeTTebdOp` class properly:
- Allocates 4D descriptor tensors with correct dimensions
- Saves necessary tensors for the backward pass
- Computes gradients only for em_x_tensor (not em_tensor)

Lines 1203-1211: Public API function properly exposed.
The `tabulate_fusion_se_t_tebd` function correctly wraps the autograd operator.

Lines 1231-1233: Torch library registration complete.
The TEBD operation is properly registered with the Torch library.
```cpp
for (int ii = 0; ii < nnei_i; ii++) {
  for (int jj = 0; jj < nnei_j; jj++) {
    FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
    int table_idx = 0;
    locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);

    FPTYPE grad_sum = 0.0;
    for (int mm = 0; mm < last_layer_size; mm++) {
      FPTYPE var[6];
      load_polynomial_params(var, table, table_idx, mm, last_layer_size);

      FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
                        4.0 * var[4] * xx * xx * xx +
                        5.0 * var[5] * xx * xx * xx * xx;

      FPTYPE dy_val =
          dy[block_idx * nnei_i * nnei_j * last_layer_size +
             ii * nnei_j * last_layer_size + jj * last_layer_size + mm];
      grad_sum += dy_val * dres_dxx;
    }

    if (thread_idx == 0) {  // Only one thread writes the gradient
      dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
    }
  }
}
```

Avoid 128× redundant compute in the TEBD grad kernel.
Every thread walks the full nnei_i * nnei_j * last_layer_size loops, yet only thread 0 ever writes to dy_dem_x. As a result the whole block repeats the same math blockDim.x times (≈128× with the current launch configuration), which is a major performance regression and risks triggering kernel timeouts on bigger tables. Gate the body to a single lane (or properly distribute the work across threads); the minimal fix below stops the duplication.

```diff
  const int_64 block_idx = blockIdx.x;  // nloc
  const int thread_idx = threadIdx.x;   // thread within block
+ if (thread_idx != 0) {
+   return;
+ }
+
  for (int ii = 0; ii < nnei_i; ii++) {
    for (int jj = 0; jj < nnei_j; jj++) {
      FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
@@
-     if (thread_idx == 0) {  // Only one thread writes the gradient
-       dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
-     }
+     dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
    }
  }
```
📝 Committable suggestion
‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
const int_64 block_idx = blockIdx.x;  // nloc
const int thread_idx = threadIdx.x;   // thread within block
// Only one thread needs to do the work; everyone else bails out immediately.
if (thread_idx != 0) {
  return;
}

for (int ii = 0; ii < nnei_i; ii++) {
  for (int jj = 0; jj < nnei_j; jj++) {
    FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
    int table_idx = 0;
    locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
    FPTYPE grad_sum = 0.0;
    for (int mm = 0; mm < last_layer_size; mm++) {
      FPTYPE var[6];
      load_polynomial_params(var, table, table_idx, mm, last_layer_size);
      FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
                        4.0 * var[4] * xx * xx * xx +
                        5.0 * var[5] * xx * xx * xx * xx;
      FPTYPE dy_val =
          dy[block_idx * nnei_i * nnei_j * last_layer_size +
             ii * nnei_j * last_layer_size + jj * last_layer_size + mm];
      grad_sum += dy_val * dres_dxx;
    }
    // Unconditionally write now that only thread 0 is left.
    dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
  }
}
```
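The analytic derivative used in the kernel, var[1] + 2·var[2]·x + ... + 5·var[5]·x^4, can be sanity-checked against a numerical derivative. A standalone sketch with arbitrary coefficients standing in for one table row:

```python
# Arbitrary quintic coefficients standing in for var[0..5] of one table row.
var = [0.3, -1.2, 0.5, 0.05, -0.01, 0.002]

def poly(x):
    """Forward value: var[0] + var[1]*x + ... + var[5]*x^5."""
    return sum(c * x**k for k, c in enumerate(var))

def dres_dxx(x):
    """Analytic derivative as written in the grad kernel."""
    return (var[1] + 2.0 * var[2] * x + 3.0 * var[3] * x * x
            + 4.0 * var[4] * x**3 + 5.0 * var[5] * x**4)

# Central finite difference agrees with the analytic form.
x, h = 0.7, 1e-6
numeric = (poly(x + h) - poly(x - h)) / (2.0 * h)
assert abs(dres_dxx(x) - numeric) < 1e-6
```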
```python
# Call the forward function
forward_result = torch.ops.deepmd.tabulate_fusion_se_t(
    self.table_tensor,
    self.table_info_tensor,
    self.em_x_tensor,
    self.em_tensor,
    self.last_layer_size,
)
```

Same issue: the backward test uses the SE_T operation instead of SE_T_TEBD.
Apply this diff to use the correct operation:

```diff
- forward_result = torch.ops.deepmd.tabulate_fusion_se_t(
+ forward_result = torch.ops.deepmd.tabulate_fusion_se_t_tebd(
```

🤖 Prompt for AI Agents: In source/tests/pt/test_tabulate_fusion_se_t_tebd.py around lines 343 to 350, the backward test is calling torch.ops.deepmd.tabulate_fusion_se_t instead of the TEBD variant; update the test to call torch.ops.deepmd.tabulate_fusion_se_t_tebd for the backward phase so the backward test uses the correct SE_T_TEBD operator.
```cpp
TORCH_LIBRARY_FRAGMENT(deepmd, m) {
  m.def("tabulate_fusion_se_t", tabulate_fusion_se_t);
}
TORCH_LIBRARY_FRAGMENT(deepmd, m) {
```

Check notice — Code scanning / CodeQL: Unused static function (Note): TORCH_LIBRARY_FRAGMENT_static_init_deepmd_5
for more information, see https://pre-commit.ci
…abulate_fusion_se_t_tebd custom OP but not used and tested
c252e57 to cea01d2
add compression support for se_e3_tebd
Summary by CodeRabbit
- New Features
- Performance Improvements
- Reliability
- Tests