Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion xmanager/cloud/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@
8: 'a2-highgpu-8g',
16: 'a2-megagpu-16g',
}
# Maps (GPU count, vCPU count) -> GCE G2 machine type for NVIDIA L4 GPUs.
# G2 machine-type names encode the vCPU count (g2-standard-<vCPUs>): the
# single-GPU shape is offered in several vCPU sizes, while the 2/4/8-GPU
# shapes each come with a fixed vCPU count.
_L4_GPUS_TO_MACHINE_TYPE = {
    (1, 4): 'g2-standard-4',
    (1, 8): 'g2-standard-8',
    (1, 12): 'g2-standard-12',
    (1, 16): 'g2-standard-16',
    (1, 32): 'g2-standard-32',
    (2, 24): 'g2-standard-24',
    (4, 48): 'g2-standard-48',
    (8, 96): 'g2-standard-96',
}

_CLOUD_TPU_ACCELERATOR_TYPES = {
xm.ResourceType.TPU_V2: 'TPU_V2',
Expand Down Expand Up @@ -97,6 +107,28 @@
),
}

def aip_v1_gpu_accelerator_type_str(gpu_type: xm.GpuType) -> str:
tesla_architectures = {xm.ResourceType.P4, xm.ResourceType.T4, xm.ResourceType.P100, xm.ResourceType.V100, xm.ResourceType.A100}
match gpu_type:
case _ if gpu_type in tesla_architectures:
return f"NVIDIA_TESLA_{gpu_type.name.upper()}"
case xm.ResourceType.L4:
return 'NVIDIA_L4'
case xm.ResourceType.L4_24TH:
return 'NVIDIA_L4'
case xm.ResourceType.A100_80GIB:
return 'NVIDIA_A100_80GB'
case xm.ResourceType.H100:
return 'NVIDIA_H100_80GB'
case xm.ResourceType.H200:
return 'NVIDIA_H200_141GB'
case xm.ResourceType.B200:
return 'NVIDIA_B200'
case _:
raise ValueError(
f'Unsupported GPU type {gpu_type}. Supported types are: {GpuType}'
)

# Hide noisy warning regarding:
# `file_cache is unavailable when using oauth2client >= 4.0.0 or google-auth`
logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)
Expand Down Expand Up @@ -298,14 +330,16 @@ def get_machine_spec(job: xm.Job) -> Dict[str, Any]:
for resource, value in requirements.task_requirements.items():
accelerator_type = None
if resource in xm.GpuType:
accelerator_type = 'NVIDIA_TESLA_' + str(resource).upper()
accelerator_type = aip_v1_gpu_accelerator_type_str(resource)
elif resource in xm.TpuType:
accelerator_type = _CLOUD_TPU_ACCELERATOR_TYPES[resource]
if accelerator_type:
spec['accelerator_type'] = aip_v1.AcceleratorType[accelerator_type]
spec['accelerator_count'] = int(value)
accelerator = spec.get('accelerator_type', None)
if accelerator and accelerator == aip_v1.AcceleratorType.NVIDIA_TESLA_A100:
print(f'Available A100 machine types (gpus: machine_type): {_A100_GPUS_TO_MACHINE_TYPE}')

for gpus, machine_type in sorted(_A100_GPUS_TO_MACHINE_TYPE.items()):
if spec['accelerator_count'] <= gpus:
spec['machine_type'] = machine_type
Expand All @@ -316,6 +350,30 @@ def get_machine_spec(job: xm.Job) -> Dict[str, Any]:
spec['accelerator_count']
)
)
elif accelerator and accelerator == aip_v1.AcceleratorType.NVIDIA_L4:
print(f'Available L4 machine types ((gpus, cpus): machine_type): {_L4_GPUS_TO_MACHINE_TYPE}')

required_gpus = spec['accelerator_count']
required_cpus = requirements.task_requirements.get(xm.ResourceType.CPU, None)
gpus_matches = lambda gpus: spec['accelerator_count'] <= gpus
cpus_matches = lambda cpus: required_cpus is None or cpus == required_cpus

l4_candidates = [
(machine_type, (gpus, cpus))
for (gpus, cpus), machine_type in _L4_GPUS_TO_MACHINE_TYPE.items()
if gpus_matches(gpus) and cpus_matches(cpus)
]

if not l4_candidates:
cpu_str = f' with {required_cpus} CPUs' if required_cpus else ''
raise ValueError(
f'l4={required_gpus}{cpu_str} does not fit in any valid machine type.'
)

# Find the best fit (smallest machine that satisfies the requirements).
# The key for sorting is (gpus, cpus).
best_fit_machine_type, _ = min(l4_candidates, key=lambda item: item[1])
spec['machine_type'] = best_fit_machine_type
elif (
accelerator == aip_v1.AcceleratorType.TPU_V2
or accelerator == aip_v1.AcceleratorType.TPU_V3
Expand Down
58 changes: 57 additions & 1 deletion xmanager/cloud/vertex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import os
import unittest
from unittest import mock
import io
from contextlib import redirect_stdout

from absl.testing import parameterized
from google import auth
from google.auth import credentials
from google.cloud import aiplatform
Expand All @@ -31,7 +34,7 @@
from xmanager.cloud import vertex # pylint: disable=g-bad-import-order


class VertexTest(unittest.TestCase):
class VertexTest(parameterized.TestCase):

@mock.patch.object(xm_auth, 'get_service_account')
@mock.patch.object(auth, 'default')
Expand Down Expand Up @@ -155,6 +158,59 @@ def test_get_machine_spec_a100(self):
},
)

  @parameterized.parameters(
      # One test case per entry of vertex._L4_GPUS_TO_MACHINE_TYPE: an exact
      # (gpus, cpus) match must resolve to the corresponding G2 machine type.
      {'cpus': 4, 'gpus': 1, 'expected': 'g2-standard-4'},
      {'cpus': 8, 'gpus': 1, 'expected': 'g2-standard-8'},
      {'cpus': 12, 'gpus': 1, 'expected': 'g2-standard-12'},
      {'cpus': 16, 'gpus': 1, 'expected': 'g2-standard-16'},
      {'cpus': 32, 'gpus': 1, 'expected': 'g2-standard-32'},
      {'cpus': 24, 'gpus': 2, 'expected': 'g2-standard-24'},
      {'cpus': 48, 'gpus': 4, 'expected': 'g2-standard-48'},
      {'cpus': 96, 'gpus': 8, 'expected': 'g2-standard-96'},
  )
  def test_get_machine_spec_l4(self, cpus, gpus, expected):
    """L4 jobs with matching CPU/GPU requirements get the right G2 machine."""
    job = xm.Job(
        executable=local_executables.GoogleContainerRegistryImage('name', ''),
        executor=local_executors.Vertex(
            requirements=xm.JobRequirements(l4=gpus, cpu=cpus)
        ),
        args={},
    )
    machine_spec = vertex.get_machine_spec(job)
    self.assertDictEqual(
        machine_spec,
        {
            'machine_type': expected,
            'accelerator_type': vertex.aip_v1.AcceleratorType.NVIDIA_L4,
            'accelerator_count': gpus,
        },
    )

  @parameterized.parameters(
      # (cpus, gpus) combinations that fit no entry in
      # vertex._L4_GPUS_TO_MACHINE_TYPE, either because no machine offers that
      # vCPU count for the GPU count, or because the vCPU count is mismatched.
      {'cpus': 3, 'gpus': 1},
      {'cpus': 4, 'gpus': 2},
      {'cpus': 25, 'gpus': 2},
      {'cpus': 41, 'gpus': 4},
      {'cpus': 48, 'gpus': 8},
  )
  def test_get_machine_spec_l4_failure(self, cpus, gpus):
    """Unmatchable L4 CPU/GPU requirements raise ValueError with diagnostics."""
    job = xm.Job(
        executable=local_executables.GoogleContainerRegistryImage('name', ''),
        executor=local_executors.Vertex(
            requirements=xm.JobRequirements(l4=gpus, cpu=cpus)
        ),
        args={},
    )
    # Capture stdout: get_machine_spec prints the table of valid L4 machine
    # types before raising, and that diagnostic output is asserted below.
    f = io.StringIO()
    with redirect_stdout(f), self.assertRaises(ValueError) as cm:
      vertex.get_machine_spec(job)

    self.assertIn('Available L4 machine types', f.getvalue())
    self.assertIn(
        # NOTE(review): the '.0' suffix implies JobRequirements stores the CPU
        # requirement as a float when it is formatted into the message —
        # confirm against xm.JobRequirements.
        f'l4={gpus} with {cpus}.0 CPUs does not fit in any valid machine type.',
        str(cm.exception),
    )

def test_get_machine_spec_tpu(self):
job = xm.Job(
executable=local_executables.GoogleContainerRegistryImage('name', ''),
Expand Down
2 changes: 2 additions & 0 deletions xmanager/xm/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class ResourceType(enum.Enum, metaclass=_CaseInsensitiveResourceTypeMeta):
LOCAL_GPU = 100006
P4 = 21
T4 = 22
L4 = 11
Copy link
Copy Markdown
Author

@hartikainen hartikainen Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what this value should be.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have L4_24TH = 68 which refers to the same GPU family, adding a separate L4 = 11 might create confusion. To keep the resource definitions clear and avoid ambiguity, I suggest we remove the new L4 entry and exclusively use L4_24TH when referring to L4 GPUs. The logic in vertex.py can then be simplified to handle just the one L4_24TH type, which it correctly maps to 'NVIDIA_L4'.
This would also resolve the uncertainty around what value L4 should have.

L4_24TH = 68
P100 = 14
V100 = 17
Expand Down Expand Up @@ -194,6 +195,7 @@ def __new__(cls, value: int) -> ResourceType:
# LOCAL_GPU is missing as only specific GPU types should be added.
ResourceType.P4,
ResourceType.T4,
ResourceType.L4,
ResourceType.L4_24TH,
ResourceType.P100,
ResourceType.V100,
Expand Down