Currently, the TE tests as implemented in _test_te.yaml consist of two parts:
- unit testing on V100 only
- multi-GPU testing on A100 via SLURM jobs.
This seems to make the testing setup unnecessarily complex.
We already have a V100/A100 unit testing framework, exemplified in _test_jax.yaml, which allows the same unit and multi-GPU test logic to be run as a matrix over GPU types and to scale from 1 to 8 GPUs.
@terrykong @ashors1 would you be able to refactor the TE tests to follow the JAX unit testing framework?
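
For illustration, here is a rough sketch of what a matrixed TE job could look like using the standard GitHub Actions `strategy.matrix` syntax. This is not copied from _test_jax.yaml: the job name, the runner labels, the GPU counts chosen within the 1-8 range, and the `run-te-tests.sh` script are all hypothetical placeholders, and the real workflow may dispatch to the GPUs differently (e.g. via SLURM).

```yaml
# Hypothetical sketch only: names, labels, and the script are placeholders,
# not taken from _test_jax.yaml or _test_te.yaml.
jobs:
  te-tests:
    strategy:
      fail-fast: false            # let every GPU/size combination report its own result
      matrix:
        gpu: [V100, A100]         # same test logic on both GPU types
        num_gpus: [1, 2, 4, 8]    # assumed scaling points within the 1-8 GPU range
    runs-on: ${{ matrix.gpu }}    # assumes runner labels match the GPU type
    steps:
      - uses: actions/checkout@v4
      - name: Run TE tests
        run: ./run-te-tests.sh --gpus ${{ matrix.num_gpus }}   # hypothetical script
```

With a layout like this, the unit tests and the multi-GPU tests would share one job definition, and the V100-only / A100-only split would become matrix entries rather than separate code paths.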