Skip to content

Commit 046ef69

Browse files
muyangyuapple authored and changlan committed
Support ct6e-standard-8t
GitOrigin-RevId: 217890c603e3b59b6f77ffd8141291d2884065b0
1 parent 21b90ab commit 046ef69

File tree

8 files changed

+30
-5
lines changed

8 files changed

+30
-5
lines changed

axlearn/cloud/gcp/jobs/launch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,6 @@ def submit(self) -> JobSpec:
402402
"""Submits the command to bastion."""
403403
cfg: BaseBastionManagedJob.Config = self.config
404404
self._bundler.bundle(cfg.name)
405-
406405
logging.info("Starting run for job name %s", cfg.name)
407406
logging.info("Command: %s", cfg.command)
408407
with tempfile.NamedTemporaryFile("w") as f:

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ def _build_container(self) -> Nested[Any]:
479479
if cfg.enable_tpu_ici_resiliency is not None:
480480
env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()
481481

482+
# This label will be used by TPU provisioner to select machine type.
482483
resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
483484
# Set request memory by host machine type.
484485
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(

axlearn/cloud/gcp/node_pool_provisioner.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,15 @@ def create_for(self, job: GKEJob):
153153
additional_labels_list.append(additional_labels)
154154

155155
start_time = time.perf_counter()
156-
topology = None if isinstance(job, _INFERENCE_JOBS) else job_sys_property.topology
156+
topology = job_sys_property.topology
157+
if job_sys_property.gce_machine_type == "ct6e-standard-8t":
158+
# If we customize chips_per_vm to use ct6e-standard-8t for v6e,
159+
# it is required not to set topology.
160+
topology = None
161+
if isinstance(job, _INFERENCE_JOBS):
162+
# Inference jobs like Flink/Beam jobs use node pool as single
163+
# host nodes, we don't set topology for them
164+
topology = None
157165
create_node_pools(
158166
node_pool_names,
159167
project=cfg.project,

axlearn/cloud/gcp/system_characteristics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,12 @@ class _SystemCharacteristics:
395395
"v6e-8": _SystemCharacteristics(
396396
"2x4", 2, "tpu-v6e-slice", "ct6e-standard-4t", 4, AcceleratorType["TPU"], "v6e-8"
397397
),
398+
# Naming convention for TPU: {version}-{cores}[-{variant}]
399+
# "-{variant}" is optional. It is used to define a spec that is different from standard.
400+
# The value can be anything as long as it is unique.
401+
"v6e-8-1": _SystemCharacteristics(
402+
"2x4", 1, "tpu-v6e-slice", "ct6e-standard-8t", 8, AcceleratorType["TPU"], "v6e-8"
403+
),
398404
"v6e-16": _SystemCharacteristics(
399405
"4x4", 4, "tpu-v6e-slice", "ct6e-standard-4t", 4, AcceleratorType["TPU"], "v6e-16"
400406
),

axlearn/cloud/gcp/tpu.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,29 @@ def infer_tpu_cores(tpu_type: str) -> int:
3535
"""Infer the number of TPU cores from the TPU type.
3636
3737
Args:
38-
tpu_type: A string of the format {version}-{cores}.
38+
tpu_type: A string of the format {version}-{cores}[-{variant}].
39+
-variant is optional.
3940
4041
Returns:
4142
Inferred number of TPU cores.
4243
"""
44+
if tpu_type.count("-") == 2:
45+
tpu_type = tpu_type[: tpu_type.rfind("-")]
4346
return int(tpu_type.rsplit("-", 1)[1])
4447

4548

4649
def infer_tpu_workers(tpu_type: str) -> int:
4750
"""Infer the number of worker processes for the given TPU type.
4851
4952
Args:
50-
tpu_type: A string of the format {version}-{cores}.
53+
tpu_type: A string of the format {version}-{cores}[-{variant}].
54+
-variant is optional.
5155
5256
Returns:
5357
Inferred number of TPU workers.
5458
"""
59+
if tpu_type.count("-") == 2:
60+
tpu_type = tpu_type[: tpu_type.rfind("-")]
5561
tpu_pattern = r"(.+)*-(\d+)"
5662
match = re.search(tpu_pattern, tpu_type)
5763
try:

axlearn/cloud/gcp/tpu_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class TpuUtilsTest(parameterized.TestCase):
2323
dict(tpu_type="v4-128", version="v4", cores=128, workers=16),
2424
dict(tpu_type="v5litepod-16", version="v5litepod", cores=16, workers=4),
2525
dict(tpu_type="v3-64", version="v3", cores=64, workers=8),
26+
dict(tpu_type="v6e-8-1", version="v6e", cores=8, workers=2),
2627
)
2728
def test_infer_utils(self, tpu_type: str, version: str, cores: int, workers: int):
2829
self.assertEqual(version, infer_tpu_version(tpu_type))

axlearn/common/compiler_options.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ class NotTpuError(ValueError):
263263

264264
# TODO(markblee): Generalize to other accelerators.
265265
def infer_tpu_type(instance_type: str) -> str:
266-
"""Infers tpu type (e.g. v4-8) from instance type (e.g. tpu-v4-8 or v4-8)."""
266+
"""Infers tpu type (e.g. v4-8 or v6e-8-1) from instance type
267+
(e.g. tpu-v4-8, v4-8, tpu-v6e-8-1 or v6e-8-1)."""
267268
if not (instance_type and re.fullmatch(r"(tpu-)?v.+-\d+", instance_type)):
268269
raise NotTpuError(f"Invalid TPU instance: {instance_type}")
269270
return instance_type.replace("tpu-", "")
@@ -282,6 +283,8 @@ def infer_tpu_version(tpu_type: str) -> str:
282283
Raises:
283284
ValueError: if the TPU version string is unknown.
284285
"""
286+
if tpu_type.count("-") == 2:
287+
tpu_type = tpu_type[: tpu_type.rfind("-")]
285288
tpu_type = infer_tpu_type(tpu_type)
286289
tpu_version = tpu_type.rsplit("-", 1)[0] # split from the last occurrence of '-'
287290
# Resolve aliases like v5e to v5litepod, since in some cases (e.g. aot compilation) v5e is

axlearn/common/compiler_options_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_xsc_compiler_options(self):
8888

8989
@parameterized.parameters(
9090
dict(tpu_type="v5e-16", expected="v5litepod"),
91+
dict(tpu_type="v6e-8-1", expected="v6e"),
9192
)
9293
def test_tpu_version_alias(self, tpu_type: str, expected: str):
9394
self.assertEqual(expected, compiler_options.infer_tpu_version(tpu_type))

0 commit comments

Comments
 (0)