Conversation

@Steboss (Contributor) commented May 9, 2025

This PR does the following:

  • remove array_serialization_test.py, which causes the test run to hang with the following error:
[pod/axlearn-14925960362-xb57t/axlearn]   File "/opt/jax/jax/experimental/array_serialization/serialization.py", line 193, in __del__
[pod/axlearn-14925960362-xb57t/axlearn]     logger.warning('Please add `.wait_until_finished()` in the main thread '
[pod/axlearn-14925960362-xb57t/axlearn] Message: 'Please add `.wait_until_finished()` in the main thread before your program finishes because there is a possibility of losing errors raised if the this class is deleted before writing is completed.'
[pod/axlearn-14925960362-xb57t/axlearn] Arguments: ()
[pod/axlearn-14925960362-xb57t/axlearn] sssssssssssssss.ssssssssssss.ssssssssssssssssssssssFsFs.s...sssssssFs.Fs [ 97%]

This makes the EKS job run out of time, so we can't collect the test results.

  • remove tests that are redundant (similar tests are already running), such as:
"/opt/axlearn/axlearn/common/deberta_test.py"
"/opt/axlearn/axlearn/common/distilbert_test.py"
"/opt/axlearn/axlearn/common/trainer_test.py"
"/opt/axlearn/axlearn/common/decoder_test.py"
"/opt/axlearn/axlearn/common/adapter_torch_test.py"
"/opt/axlearn/axlearn/common/attention_test.py"
"/opt/axlearn/axlearn/common/convolution_test.py"
  • remove tests for models that we're not currently using:
"/opt/axlearn/axlearn/common/mixture_of_experts_test.py"
"/opt/axlearn/axlearn/common/t5_test.py"
"/opt/axlearn/axlearn/common/vision_transformer_test.py"
"/opt/axlearn/axlearn/common/input_reading_comprehension_test.py"
"/opt/axlearn/axlearn/common/input_t5_test.py"
  • remove tests like summary_writer_test.py that mostly exercise Python libraries we're not using here (e.g. wandb)
  • add the installation of pytest-xdist and pytest-reportlog to avoid the following error:
ERROR: usage: pytest [options] [file_or_dir] [file_or_dir] [...]
pytest: error: unrecognized arguments: --report-log=/tmp/tmp.iL7rVQtXBq --dist=load --tx --tx 
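For context on the first bullet: the warning fires when an async serialization writer is garbage-collected before its background write completes, so errors raised on the writer thread can be silently lost. A toy Python sketch of that pattern (the `AsyncWriter` class is hypothetical, not AXLearn's or JAX's actual code):

```python
import threading
import time

class AsyncWriter:
    """Toy stand-in for an async array-serialization manager."""

    def __init__(self):
        self._done = False
        self._error = None
        self._thread = None

    def serialize(self, data):
        # Launch the (slow) write on a background thread and return immediately.
        def _write():
            time.sleep(0.05)  # simulate slow I/O
            if data is None:
                self._error = ValueError("nothing to write")
            self._done = True

        self._thread = threading.Thread(target=_write)
        self._thread.start()

    def wait_until_finished(self):
        # Join the writer thread so errors raised during the write are not
        # lost when the object is deleted at interpreter shutdown.
        self._thread.join()
        if self._error is not None:
            raise self._error

writer = AsyncWriter()
writer.serialize([1, 2, 3])
writer.wait_until_finished()
```

Without the final `wait_until_finished()` call, the interpreter can exit (or the object can be collected) mid-write, which is exactly what the `__del__` warning in the log is complaining about.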

Overall, this should reduce the testing time from 50 minutes to 30 minutes, while still covering the most important tests, i.e. those exercising the general AXLearn infrastructure.

@olupton (Collaborator) left a comment


I see an error that seems like a test/infra bug, rather than a failing test: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/14931992154/job/41952767038#step:6:41440

Some of the tests look like they are failing due to missing input data.

Also, the CI job is marked as successful despite the tests failing.

@Steboss (Contributor, Author) commented May 13, 2025

In this PR I added a workflow dispatch, so we can test individual parts of the CI.
In particular, this may be a useful mechanism during CI: with MODE=SOMETHING we could trigger only specific parts of the workflow, rather than testing everything.
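A hypothetical sketch of what such a gated `workflow_dispatch` could look like (job names, input values, and steps are illustrative, not the actual JAX-Toolbox workflow file):

```yaml
# Hypothetical sketch: gate parts of the CI on a MODE-style input.
on:
  workflow_dispatch:
    inputs:
      mode:
        description: "Which part of the CI to run"
        default: "all"
        type: choice
        options: [all, axlearn]

jobs:
  axlearn-tests:
    # Runs only when the whole pipeline or just the axlearn part is requested.
    if: ${{ inputs.mode == 'all' || inputs.mode == 'axlearn' }}
    runs-on: ubuntu-latest
    steps:
      - run: echo "running axlearn tests"
```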

@Steboss Steboss requested a review from olupton May 13, 2025 16:41
@Steboss (Contributor, Author) commented May 14, 2025

@olupton
I can see that the axlearn tests are now working fine. Some tests are still failing; I can have a look at those.
I'm also working on the way we monitor and return the k8s job status, as the axlearn EKS job is still incorrectly reported green.

@Steboss (Contributor, Author) commented May 14, 2025

@olupton
It looks like we need XLA_FLAGS="--xla_force_host_platform_device_count=8" for the for_8_devices tests; otherwise they fail with:

[pod/axlearn-15024599931-2vfmc/axlearn] ___________________ HostArrayTest.test_fixed_process_shape67 ___________________
[pod/axlearn-15024599931-2vfmc/axlearn] [gw93] linux -- Python 3.12.3 /usr/bin/python3
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] self = <axlearn.common.host_array_test.HostArrayTest testMethod=test_fixed_process_shape67>
[pod/axlearn-15024599931-2vfmc/axlearn] platform = 'cpu', mesh_shape = (-1, 2), process_shape = [1]
[pod/axlearn-15024599931-2vfmc/axlearn] partition = PartitionSpec('data', 'model')
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn]     @parameterized.product(
[pod/axlearn-15024599931-2vfmc/axlearn]         platform=("cpu", "tpu"),
[pod/axlearn-15024599931-2vfmc/axlearn]         mesh_shape=[
[pod/axlearn-15024599931-2vfmc/axlearn]             (-1, 1),  # Fully partitioned along one dim.
[pod/axlearn-15024599931-2vfmc/axlearn]             (2, -1),  # Partitioned along multiple dims.
[pod/axlearn-15024599931-2vfmc/axlearn]             (-1, 2),  # Test the other way.
[pod/axlearn-15024599931-2vfmc/axlearn]             (1, -1),
[pod/axlearn-15024599931-2vfmc/axlearn]         ],
[pod/axlearn-15024599931-2vfmc/axlearn]         process_shape=[
[pod/axlearn-15024599931-2vfmc/axlearn]             # Each process produces single dim.
[pod/axlearn-15024599931-2vfmc/axlearn]             [1],  # Not divisible by number of devices (replicated).
[pod/axlearn-15024599931-2vfmc/axlearn]             [8],  # Divisible by number of devices.
[pod/axlearn-15024599931-2vfmc/axlearn]             [16],  # Multiple elements per device.
[pod/axlearn-15024599931-2vfmc/axlearn]             # Each process produces multiple dims.
[pod/axlearn-15024599931-2vfmc/axlearn]             [1, 1],  # Not divisible by number of devices (replicated).
[pod/axlearn-15024599931-2vfmc/axlearn]             [2, 1],  # Can be partitioned over dim=0, replicated on dim=1.
[pod/axlearn-15024599931-2vfmc/axlearn]             [16, 1],  # Multiple elements per device.
[pod/axlearn-15024599931-2vfmc/axlearn]             [2, 4],  # Can be fully partitioned.
[pod/axlearn-15024599931-2vfmc/axlearn]             [8, 8],  # Can be fully partitioned.
[pod/axlearn-15024599931-2vfmc/axlearn]         ],
[pod/axlearn-15024599931-2vfmc/axlearn]         partition=(
[pod/axlearn-15024599931-2vfmc/axlearn]             DataPartitionType.FULL,
[pod/axlearn-15024599931-2vfmc/axlearn]             DataPartitionType.REPLICATED,
[pod/axlearn-15024599931-2vfmc/axlearn]             PartitionSpec("data"),
[pod/axlearn-15024599931-2vfmc/axlearn]             PartitionSpec("data", "model"),
[pod/axlearn-15024599931-2vfmc/axlearn]         ),
[pod/axlearn-15024599931-2vfmc/axlearn]     )
[pod/axlearn-15024599931-2vfmc/axlearn]     # NOTE: while annotated with `for_8_devices`, this runs on other configurations.
[pod/axlearn-15024599931-2vfmc/axlearn]     @pytest.mark.for_8_devices
[pod/axlearn-15024599931-2vfmc/axlearn]     def test_fixed_process_shape(
[pod/axlearn-15024599931-2vfmc/axlearn]         self,
[pod/axlearn-15024599931-2vfmc/axlearn]         platform: str,
[pod/axlearn-15024599931-2vfmc/axlearn]         mesh_shape: tuple[int, int],
[pod/axlearn-15024599931-2vfmc/axlearn]         process_shape: Sequence[int],
[pod/axlearn-15024599931-2vfmc/axlearn]         partition: Union[DataPartitionType, PartitionSpec],
[pod/axlearn-15024599931-2vfmc/axlearn]     ):
[pod/axlearn-15024599931-2vfmc/axlearn]         """Tests roundtrip host-to-global and global-to-host with fixed process shape."""
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn] >       mesh_shape = infer_mesh_shape(mesh_shape)
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] axlearn/common/host_array_test.py:124: 
[pod/axlearn-15024599931-2vfmc/axlearn] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] mesh_shape = (-1, 2)
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn]     def infer_mesh_shape(mesh_shape: MeshShape, *, num_devices: Optional[int] = None) -> MeshShape:
[pod/axlearn-15024599931-2vfmc/axlearn]         """Infer the value for -1 from len(jax.devices()) and other dims if there is -1 in mesh shape.
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         Args:
[pod/axlearn-15024599931-2vfmc/axlearn]             mesh_shape: The original MeshShape, which might have -1 in one axis.
[pod/axlearn-15024599931-2vfmc/axlearn]             num_devices: The devices that will be used to construct the mesh.
[pod/axlearn-15024599931-2vfmc/axlearn]                 If None, defaults to len(jax.devices()).
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         Returns
[pod/axlearn-15024599931-2vfmc/axlearn]             A new MeshShape with inferred value for -1.
[pod/axlearn-15024599931-2vfmc/axlearn]         """
[pod/axlearn-15024599931-2vfmc/axlearn]         if -1 not in mesh_shape:
[pod/axlearn-15024599931-2vfmc/axlearn]             return mesh_shape
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         if mesh_shape.count(-1) > 1:
[pod/axlearn-15024599931-2vfmc/axlearn]             raise ValueError(f"Only one axis can be -1 in {mesh_shape=}.")
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         # Handle the case with one -1.
[pod/axlearn-15024599931-2vfmc/axlearn]         prod = math.prod(mesh_shape, start=-1)
[pod/axlearn-15024599931-2vfmc/axlearn]         if num_devices is None:
[pod/axlearn-15024599931-2vfmc/axlearn]             num_devices = len(jax.devices())
[pod/axlearn-15024599931-2vfmc/axlearn]         if num_devices % prod != 0:
[pod/axlearn-15024599931-2vfmc/axlearn] >           raise ValueError(
[pod/axlearn-15024599931-2vfmc/axlearn]                 f"Unable to infer -1 in mesh shape {mesh_shape} as num_devices {num_devices} "
[pod/axlearn-15024599931-2vfmc/axlearn]                 f"is not a multiple of the product {prod} of mesh axes."
[pod/axlearn-15024599931-2vfmc/axlearn]             )
[pod/axlearn-15024599931-2vfmc/axlearn] E           ValueError: Unable to infer -1 in mesh shape (-1, 2) as num_devices 1 is not a multiple of the product 2 of mesh axes.
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] axlearn/common/utils.py:1834: ValueError
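The failure can be reproduced outside the cluster; below is a standalone copy of the `infer_mesh_shape` logic shown in the traceback. With only 1 visible device, `(-1, 2)` cannot be inferred (the error above); with 8 devices it resolves to `(4, 2)`:

```python
import math

def infer_mesh_shape(mesh_shape, *, num_devices):
    """Standalone copy of the logic from the traceback above."""
    if -1 not in mesh_shape:
        return mesh_shape
    if mesh_shape.count(-1) > 1:
        raise ValueError(f"Only one axis can be -1 in {mesh_shape=}.")
    # math.prod(..., start=-1) multiplies every axis including the single -1,
    # so the result is the product of the known (non -1) axes.
    prod = math.prod(mesh_shape, start=-1)
    if num_devices % prod != 0:
        raise ValueError(
            f"Unable to infer -1 in mesh shape {mesh_shape} as num_devices "
            f"{num_devices} is not a multiple of the product {prod} of mesh axes."
        )
    return tuple(num_devices // prod if d == -1 else d for d in mesh_shape)

print(infer_mesh_shape((-1, 2), num_devices=8))  # (4, 2)
```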

I ran the tests with and without the flag:

  • with flag: 1 failed, 161 passed
  • without flag: 203 failed, 100 passed
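For reference, the flag only takes effect if it is set before JAX is first imported; a minimal sketch of how the tests would pick it up in-process:

```python
import os

# Must be set before the first `import jax`; it has no effect afterwards.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# The CPU backend now exposes 8 virtual devices, matching `for_8_devices`.
assert len(jax.devices("cpu")) == 8
```

In CI it is the same idea, just exported in the job environment rather than set in Python.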

@Steboss (Contributor, Author) commented May 14, 2025

This may be an error introduced by a newer JAX version:

AttributeError: module 'jax.experimental.array_serialization.serialization' has no attribute '_spec_has_metadata'

@Steboss Steboss requested a review from olupton May 19, 2025 16:39
olupton previously approved these changes May 20, 2025

@olupton (Collaborator) left a comment


I think we can merge this to speed up the pipeline, but I left some more comments on error handling and robustness.

@Steboss Steboss requested a review from olupton May 21, 2025 12:32
@olupton (Collaborator) left a comment


I think we can merge this as an improvement to the current situation.

The job is still marked green despite test errors/failures, although they do seem to be captured in the badge. This should be fixed in a follow-up.

@Steboss Steboss merged commit 1ea9d31 into main May 22, 2025
66 of 74 checks passed
@Steboss Steboss deleted the sbosisio/fix_axlearn_tests branch May 22, 2025 09:34
