Skip to content

Commit 7433d00

Browse files
authored
Merge branch 'main' into alechan/add-te
2 parents 6a4cd59 + 1b6845a commit 7433d00

File tree

10 files changed

+1548
-38
lines changed

10 files changed

+1548
-38
lines changed

.github/container/Dockerfile.base

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# syntax=docker/dockerfile:1-labs
2-
ARG BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.01-cuda12.8-devel-ubuntu24.04
2+
ARG BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.02-cuda12.8-devel-ubuntu24.04
33
ARG GIT_USER_NAME="JAX Toolbox"
44
55
ARG CLANG_VERSION=18
@@ -42,6 +42,7 @@ apt_packages=(
4242
vim
4343
wget
4444
jq
45+
zip
4546
# llvm.sh
4647
lsb-release
4748
software-properties-common

.github/container/Dockerfile.maxtext

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@ for pattern in \
2222
"s|absl-py|absl-py>=2.1.0|g" \
2323
"s|protobuf==3.20.3|protobuf>=3.19.0|g" \
2424
"s|tensorflow-datasets|tensorflow-datasets>=4.8.0|g" \
25+
"s|sentencepiece==0.1.97|sentencepiece>=0.2|g" \
2526
; do
2627
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
2728
done
28-
echo -e "\ntensorflow-metadata>=1.15.0" >> ${SRC_PATH_MAXTEXT}/requirements.txt
29+
# add new line in case requirements.txt does not end with a new line
30+
echo >> ${SRC_PATH_MAXTEXT}/requirements.txt
31+
for requirement in \
32+
"tensorflow-metadata>=1.15.0" \
33+
"seqio@git+https://github.com/google/seqio.git" \
34+
; do
35+
echo "${requirement}" >> ${SRC_PATH_MAXTEXT}/requirements.txt
36+
done
2937
EOF
3038

3139
###############################################################################

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,9 @@ def _load_nvtx_gpu_proj_trace_single(
230230
mod_id_names = df.loc[mod_ids, "Name"]
231231
assert mod_ids.shape == mod_id_names.shape
232232
# Get a mask in mod_id_names of entries where ModuleId in the original
233-
# Thunk is not referring to a Module. If it's not a module, it should
234-
# be a thunk.
233+
# Thunk is not referring to a Module yet. Intermediate levels of the
234+
# hierarchy can be other thunks (e.g. an individual graph node may
235+
# have a thunk representing the whole graph as a parent).
235236
mask = ~mod_id_names.str.startswith(module_prefix)
236237
assert (mask == mod_id_names.str.startswith(thunk_prefix)).all()
237238
assert mask.shape == mod_ids.shape

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def gather_source_files(
574574
if src_file == "<string>":
575575
# This can appear due to python -c "...", for example.
576576
continue
577+
if src_file == "<frozen runpy>":
578+
continue
577579
assert osp.isabs(src_file), f"{src_file} is not absolute"
578580
output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE))
579581
print(f"{archive_name}: gathered source code in {time.time() - start:.2f}s")

0 commit comments

Comments
 (0)