Skip to content

Commit 9983cb0

Browse files
authored
Allow additional args to be passed many times
1 parent 70f67a8 commit 9983cb0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

.github/container/test-maxtext.sh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ usage() {
1313
echo "Usage: $0 [OPTIONS]"
1414
echo ""
1515
echo " OPTIONS DESCRIPTION"
16-
echo " -a, --additional-args Additional args to pass to MaxText/train.py"
16+
echo " -a, --additional-args Additional args to pass to MaxText/train.py. Can be passed many times."
1717
echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65". Default to 0.90, contradicting JAX default of 0.75.
1818
echo " --model-name Specify the model names to run [Preferred]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models:
1919
gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
@@ -34,7 +34,7 @@ usage() {
3434
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
3535
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
3636
3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess
37-
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false
37+
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a "scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"
3838
5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess
3939
6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
4040
7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
@@ -76,7 +76,7 @@ eval set -- "$args"
7676
while [ : ]; do
7777
case "$1" in
7878
-a | --additional-args)
79-
ADDITIONAL_ARGS="$2"
79+
ADDITIONAL_ARGS="$ADDITIONAL_ARGS $2"
8080
shift 2
8181
;;
8282
--mem-fraction)

0 commit comments

Comments (0)