
Commit 4eeff71

Merge remote-tracking branch 'origin/main' into 24.10-devel-add-ngs-release-testing

2 parents 0a09276 + ccededf

File tree

4 files changed (+57, -21 lines)

.github/container/Dockerfile.base

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # syntax=docker/dockerfile:1-labs
-ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04
+ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04
 ARG GIT_USER_NAME="JAX Toolbox"

-ARG CLANG_VERSION=17
+ARG CLANG_VERSION=18

 ###############################################################################
 ## Obtain GCP's NCCL TCPx plugin
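
For anyone rebuilding the base image locally, both ARGs remain overridable at build time. A minimal sketch (the image tag is illustrative, not from this repo's CI):

    # Build the base image, pinning the same CUDA and Clang versions the
    # commit sets as defaults; the -t tag name here is hypothetical.
    docker build \
        --build-arg BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04 \
        --build-arg CLANG_VERSION=18 \
        -f .github/container/Dockerfile.base \
        -t jax-toolbox-base:local .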

.github/container/test-pax.sh

Lines changed: 30 additions & 19 deletions
@@ -15,7 +15,8 @@ usage() {
     echo "  -a, --additional-args    Additional fiddle args to pass to paxml/main.py"
     echo "  -b, --batch-per-gpu      Batch size per GPU, defaults to 4."
     echo "  --dtype                  Batch size, defaults to bfloat16."
-    echo "  --enable-te              If set, will run with env var ENABLE_TE=1."
+    echo "  --enable-te              If set, will run with env var ENABLE_TE=1."
+    echo "  --enable-cudnn-fa        If set, will use cudnn fa."
     echo "  --enable-dropout         If set, will set DROPOUT_PROB to 0.1."
     echo "  --disable-fused-attn     Whether disable TE fused attention."
     echo "  --model-type             One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
@@ -26,13 +27,13 @@ usage() {
     echo "  --data-parallel          Data parallelism to use. Defaults to 1."
     echo "  --fsdp                   Fully-sharded data parallelism to use. Defaults to 1."
     echo "  --tensor-parallel        Tensor parallelism to use. Defaults to 1."
-    echo "  --pipeline-parallel      Pipeline parallelism to use. Defaults to 1 for no pipelining."
+    echo "  --pipeline-parallel      Pipeline parallelism to use. Defaults to 1 for no pipelining."
     echo "  -n, --nodes              Number of nodes."
     echo "  -h, --help               Print usage."
     exit $1
 }

-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit $1
 fi
@@ -50,6 +51,7 @@ TP=1
 PP=1
 NODES=1
 ENABLE_TE=0
+ENABLE_CUDNN_FA=0
 MODEL_TYPE=126M
 NVTE_FUSED_ATTN=1
 DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
         ENABLE_TE=1
         shift 1
         ;;
+    --enable-cudnn-fa)
+        ENABLE_CUDNN_FA=1
+        shift 1
+        ;;
     --enable-dropout)
         DROPOUT='0.1'
         shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
         ;;
     --)
         shift;
-        break
+        break
         ;;
     *)
         echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var ENABLE_TE
+print_var ENABLE_CUDNN_FA
 print_var NVTE_FUSED_ATTN
 print_var EVALUATE
 print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
     if dp % dcn_factor == 0:
         dcn_dp = dcn_factor
         dp = int(dp / dcn_factor)
-    elif fsdp % dcn_factor == 0:
+    elif fsdp % dcn_factor == 0:
         dcn_fsdp = dcn_factor
         fsdp = int(fsdp / dcn_factor)
-    elif pp % dcn_factor == 0:
+    elif pp % dcn_factor == 0:
         dcn_pp = dcn_factor
         pp = int(pp / dcn_factor)

@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     USE_REPEATED_LAYER = False
     ICI_MESH_SHAPE = [64,1,1]
     MAX_STEPS = 600000
-
+
     MAX_SEQ_LEN = 2048
     VOCAB_SIZE = 50304
     PACKED_INPUT = True
     PERCORE_BATCH_SIZE = 4
-
+
     NUM_LAYERS = 12
     NUM_HEADS = 12
     MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):

     TRAINABLE_POSITION_EMB = True
     TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN
-
+
     USE_BIAS = True
     LAYERNORM_EPSILON = 1e-5
     ATTEN_LOGIT_CAP = -1.0
     INIT_STD = 0.023
     SOFTMAX_INIT_STD = 0.023
     ACTIVATION_CLS = layers.GELU
-
+
     ## optimizer-related
     ADAM_BETA1 = 0.9
     ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     ## disable eval to avoid including eval
     ## in steps/sec calculation
     EVAL_INTERVAL_STEPS = 100000
-
+
     def task(self):
         task_p = super().task()
         task_p = configure_gpt3_task(self, task_p)

         task_p.train.num_train_steps = self.MAX_STEPS

         model_p = task_p.model
-
+
         ### compute layernorm reductions in fp32. Needed for stable training on GPUs
         stacked_p = model_p.lm_tpl.stacked_transformer_tpl
         if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
             transformer_layer_p.ln_tpl.reductions_in_fp32 = True
             transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
         task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
-
+
         model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
         softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
         model_p.lm_tpl.softmax_tpl.params_init = softmax_init
-
+
         model_p.apply_eval_sample_weights = True
-
+
         ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
         stacked_p.dropout_prob = 0.0
         stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
 if pp > 1:
     @experiment_registry.register
     class Synthetic126MCI(GPT126MPP, SyntheticDataset):
-
+
         ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
         MICROBATCH_SIZE = 2
         NUM_STAGES = pp
         PERCORE_BATCH_SIZE = percore_batch_size
         FRPOP_DTYPE = dtype
-
+
         def task(self):
             task_p = super().task()
             task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
 else:
     @experiment_registry.register
     class Synthetic126MCI(Synthetic126M):
-
+
         ICI_MESH_SHAPE = [dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
         PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:

         ## disable eval
         EVAL_INTERVAL_STEPS = 100000
-
+
         def task(self):
             task_p = super().task()

@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
 export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
 export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}

+if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
+    ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
+fi
+
 if [[ ${MODEL_TYPE} == "126M" ]]; then
     CONFIG=ci_configs.Synthetic126MCI
 elif [[ ${MODEL_TYPE} == "5B" ]]; then
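
Taken together, the new --enable-cudnn-fa switch just appends a fiddle override (--fdl.USE_CUDNN_FLASH_ATTENTION=True) to ADDITIONAL_ARGS. A sketch of an invocation exercising it; all option values here are illustrative, not taken from this commit:

    # Hypothetical single-node run with TE and cuDNN flash attention enabled.
    bash .github/container/test-pax.sh \
        --model-type 126M \
        --enable-te \
        --enable-cudnn-fa \
        --steps 100 \
        --output /tmp/test-pax-output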

.github/workflows/_ci.yaml

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,11 @@ on:
         description: 'Build date in YYYY-MM-DD format'
         required: false
         default: NOT SPECIFIED
+      CUDA_IMAGE:
+        type: string
+        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
+        default: 'latest'
+        required: false
       MANIFEST_ARTIFACT_NAME:
         type: string
         description: 'Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch'
@@ -37,6 +42,7 @@ jobs:
     uses: ./.github/workflows/_build_base.yaml
     with:
       ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
+      BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
       BUILD_DATE: ${{ inputs.BUILD_DATE }}
       MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
     secrets: inherit
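
How _build_base.yaml interprets a BASE_IMAGE of 'latest' is not shown in this diff; presumably it falls back to the Dockerfile's own default. A hedged bash sketch of that presumed pattern, not the actual workflow logic:

    # Sketch only: assumes 'latest' (or empty) means "keep the Dockerfile's
    # default BASE_IMAGE"; the real handling lives in _build_base.yaml.
    if [[ -z "${CUDA_IMAGE}" || "${CUDA_IMAGE}" == "latest" ]]; then
        docker build -f .github/container/Dockerfile.base .
    else
        docker build --build-arg BASE_IMAGE="${CUDA_IMAGE}" \
            -f .github/container/Dockerfile.base .
    fi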

.github/workflows/ci.yaml

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,11 @@ on:
         description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch"
         default: false
         required: false
+      CUDA_IMAGE:
+        type: string
+        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
+        default: 'latest'
+        required: false
       SOURCE_OVERRIDES:
         type: string
         description: |
@@ -60,6 +65,7 @@ jobs:
       MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }}
       MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }}
       MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMBED_MANIFEST }}
+      CUDA_IMAGE: ${{ steps.cuda-image.outputs.CUDA_IMAGE }}
     steps:
       - name: Cancel workflow run if the trigger is a draft PR
         id: cancel-if-draft
@@ -114,6 +120,17 @@ jobs:
           exit 1
         fi

+      - name: Determine CUDA image to use
+        id: cuda-image
+        shell: bash -x -e {0}
+        run: |
+          if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
+            CUDA_IMAGE="${{ inputs.CUDA_IMAGE }}"
+          else
+            CUDA_IMAGE="latest"
+          fi
+          echo "CUDA_IMAGE=${CUDA_IMAGE}" >> $GITHUB_OUTPUT
+
   bump-manifest:
     needs: metadata
     runs-on: ubuntu-22.04
@@ -177,6 +194,7 @@ jobs:
     with:
       ARCHITECTURE: amd64
       BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
+      CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
     secrets: inherit
@@ -187,6 +205,7 @@ jobs:
     with:
       ARCHITECTURE: arm64
       BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
+      CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
     secrets: inherit
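
Since the new CUDA_IMAGE input only takes effect on workflow_dispatch (all other triggers fall back to 'latest'), a manual dispatch is the way to exercise it. A sketch using the GitHub CLI; the ref and image value are illustrative:

    # Hypothetical manual dispatch pinning the CUDA base image.
    gh workflow run ci.yaml \
        --ref main \
        -f CUDA_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04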
