
Commit b9396f5

Uniquely identify AWS resources for Dask clusters running in the same ECS cluster (#474)
* Introduce unique IDs for Dask clusters running in the same ECS cluster.
* Black passing
* Use cluster.name instead of creating a new dask cluster id.
* Unused import
* Unit tests
1 parent 8b5ab97 commit b9396f5
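
What this enables, in a minimal usage sketch: two Dask clusters sharing one pre-existing ECS cluster without their AWS resources colliding. The ARN and names below are placeholders, and the snippet assumes working AWS credentials:

    from dask_cloudprovider.aws import ECSCluster

    # Placeholder ARN for an ECS cluster that both Dask clusters will share
    shared_arn = "arn:aws:ecs:us-east-1:123456789012:cluster/shared-ecs"

    # Each Dask cluster gets its own ``name``; with this change the IAM roles,
    # security group, task definition families and log stream prefix all embed
    # that name, so the two clusters no longer clobber each other's resources.
    cluster_a = ECSCluster(name="analytics", cluster_arn=shared_arn)
    cluster_b = ECSCluster(name="reporting", cluster_arn=shared_arn)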

File tree (3 files changed, +183 / -27 lines):

* dask_cloudprovider/aws/ecs.py
* dask_cloudprovider/aws/tests/test_ecs.py
* dask_cloudprovider/cloudprovider.yaml

dask_cloudprovider/aws/ecs.py

Lines changed: 54 additions & 25 deletions
@@ -1,6 +1,5 @@
 import asyncio
 import logging
-import uuid
 import warnings
 import weakref
 from typing import List, Optional
@@ -224,9 +223,9 @@ async def start(self):
                 "awsvpcConfiguration": {
                     "subnets": self._vpc_subnets,
                     "securityGroups": self._security_groups,
-                    "assignPublicIp": "ENABLED"
-                    if self._use_public_ip
-                    else "DISABLED",
+                    "assignPublicIp": (
+                        "ENABLED" if self._use_public_ip else "DISABLED"
+                    ),
                 }
             },
         }
@@ -461,7 +460,9 @@ class ECSCluster(SpecCluster, ConfigMixin):
     This creates a dask scheduler and workers on an existing ECS cluster.

     All the other required resources such as roles, task definitions, tasks, etc
-    will be created automatically like in :class:`FargateCluster`.
+    will be created automatically like in :class:`FargateCluster`. Resource names will
+    include the value of `self.name` to uniquely associate them with this cluster, and
+    they will also be tagged with `dask_cluster_name` using the same value.

     Parameters
     ----------
@@ -579,9 +580,11 @@ class ECSCluster(SpecCluster, ConfigMixin):
         Defaults to ``None`` which results in a new cluster being created for you.
     cluster_name_template: str (optional)
         A template to use for the cluster name if ``cluster_arn`` is set to
-        ``None``.
+        ``None``. Valid substitution variables are:

-        Defaults to ``'dask-{uuid}'``
+        ``name`` <= self.name (usually a UUID)
+
+        Defaults to ``'dask-{name}'``
     execution_role_arn: str (optional)
         The ARN of an existing IAM role to use for ECS execution.

@@ -626,9 +629,12 @@ class ECSCluster(SpecCluster, ConfigMixin):

         Default ``None`` (one will be created called ``dask-ecs``)
     cloudwatch_logs_stream_prefix: str (optional)
-        Prefix for log streams.
+        Prefix for log streams. Valid substitution variables are:
+
+        ``name`` <= self.name (usually a UUID)
+        ``cluster_name`` <= self.cluster_name (ECS cluster name)

-        Defaults to the cluster name.
+        Defaults to ``{cluster_name}/{name}``.
     cloudwatch_logs_default_retention: int (optional)
         Retention for logs in days. For use when log group is auto created.

@@ -921,7 +927,10 @@ async def _start(
         if self._cloudwatch_logs_stream_prefix is None:
             self._cloudwatch_logs_stream_prefix = self.config.get(
                 "cloudwatch_logs_stream_prefix"
-            ).format(cluster_name=self.cluster_name)
+            ).format(
+                cluster_name=self.cluster_name,
+                name=self.name,
+            )

         if self.cloudwatch_logs_group is None:
             self.cloudwatch_logs_group = (
@@ -1025,7 +1034,12 @@ def _new_worker_name(self, worker_number):

     @property
     def tags(self):
-        return {**self._tags, **DEFAULT_TAGS, "cluster": self.cluster_name}
+        return {
+            **self._tags,
+            **DEFAULT_TAGS,
+            "cluster": self.cluster_name,
+            "dask_cluster_name": self.name,
+        }

     async def _create_cluster(self):
         if not self._fargate_scheduler or not self._fargate_workers:
@@ -1038,7 +1052,10 @@ async def _create_cluster(self):
         self.cluster_name = dask.config.expand_environment_variables(
             self._cluster_name_template
         )
-        self.cluster_name = self.cluster_name.format(uuid=str(uuid.uuid4())[:10])
+        self.cluster_name = self.cluster_name.format(
+            name=self.name,
+            uuid=self.name,  # backwards-compatible
+        )
         async with self._client("ecs") as ecs:
             response = await ecs.create_cluster(
                 clusterName=self.cluster_name,
@@ -1059,7 +1076,7 @@ async def _delete_cluster(self):

     @property
     def _execution_role_name(self):
-        return "{}-{}".format(self.cluster_name, "execution-role")
+        return "dask-{}-execution-role".format(self.name)

     async def _create_execution_role(self):
         async with self._client("iam") as iam:
@@ -1099,7 +1116,7 @@ async def _create_execution_role(self):

     @property
     def _task_role_name(self):
-        return "{}-{}".format(self.cluster_name, "task-role")
+        return "dask-{}-task-role".format(self.name)

     async def _create_task_role(self):
         async with self._client("iam") as iam:
@@ -1141,6 +1158,8 @@ async def _delete_role(self, role):
             await iam.delete_role(RoleName=role)

     async def _create_cloudwatch_logs_group(self):
+        # The log group does not include `name` because it is shared by all Dask ECS clusters. But,
+        # log streams do because they are specific to each Dask cluster.
         log_group_name = "dask-ecs"
         async with self._client("logs") as logs:
             groups = await logs.describe_log_groups()
@@ -1160,23 +1179,29 @@ async def _create_cloudwatch_logs_group(self):
         # Note: Not cleaning up the logs here as they may be useful after the cluster is destroyed
         return log_group_name

+    @property
+    def _security_group_name(self):
+        return "dask-{}-security-group".format(self.name)
+
     async def _create_security_groups(self):
         async with self._client("ec2") as client:
             group = await create_default_security_group(
-                client, self.cluster_name, self._vpc, self.tags
+                client, self._security_group_name, self._vpc, self.tags
             )
             weakref.finalize(self, self.sync, self._delete_security_groups)
             return [group]

     async def _delete_security_groups(self):
         timeout = Timeout(
-            30, "Unable to delete AWS security group " + self.cluster_name, warn=True
+            30,
+            "Unable to delete AWS security group {}".format(self._security_group_name),
+            warn=True,
         )
         async with self._client("ec2") as ec2:
             while timeout.run():
                 try:
                     await ec2.delete_security_group(
-                        GroupName=self.cluster_name, DryRun=False
+                        GroupName=self._security_group_name, DryRun=False
                     )
                 except Exception:
                     await asyncio.sleep(2)
@@ -1185,7 +1210,7 @@ async def _delete_security_groups(self):
     async def _create_scheduler_task_definition_arn(self):
         async with self._client("ecs") as ecs:
             response = await ecs.register_task_definition(
-                family="{}-{}".format(self.cluster_name, "scheduler"),
+                family="dask-{}-scheduler".format(self.name),
                 taskRoleArn=self._task_role_arn,
                 executionRoleArn=self._execution_role_arn,
                 networkMode="awsvpc",
@@ -1223,14 +1248,18 @@ async def _create_scheduler_task_definition_arn(self):
                                 "awslogs-create-group": "true",
                             },
                         },
-                        "mountPoints": self._mount_points
-                        if self._mount_points and self._mount_volumes_on_scheduler
-                        else [],
+                        "mountPoints": (
+                            self._mount_points
+                            if self._mount_points and self._mount_volumes_on_scheduler
+                            else []
+                        ),
                     }
                 ],
-                volumes=self._volumes
-                if self._volumes and self._mount_volumes_on_scheduler
-                else [],
+                volumes=(
+                    self._volumes
+                    if self._volumes and self._mount_volumes_on_scheduler
+                    else []
+                ),
                 requiresCompatibilities=["FARGATE"] if self._fargate_scheduler else [],
                 runtimePlatform={"cpuArchitecture": self._cpu_architecture},
                 cpu=str(self._scheduler_cpu),
@@ -1255,7 +1284,7 @@ async def _create_worker_task_definition_arn(self):
         )
         async with self._client("ecs") as ecs:
             response = await ecs.register_task_definition(
-                family="{}-{}".format(self.cluster_name, "worker"),
+                family="dask-{}-worker".format(self.name),
                 taskRoleArn=self._task_role_arn,
                 executionRoleArn=self._execution_role_arn,
                 networkMode="awsvpc",
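
Taken together, every per-cluster AWS resource name is now derived from ``self.name`` rather than from the shared ECS cluster name. A rough illustration of the names produced for a Dask cluster created with ``name="Spooky"`` on an ECS cluster called ``Crunchy`` (values follow the naming patterns in the diff above, not a live run):

    name = "Spooky"           # self.name
    cluster_name = "Crunchy"  # shared ECS cluster name

    derived_resource_names = {
        "execution_role": f"dask-{name}-execution-role",
        "task_role": f"dask-{name}-task-role",
        "security_group": f"dask-{name}-security-group",
        "scheduler_task_family": f"dask-{name}-scheduler",
        "worker_task_family": f"dask-{name}-worker",
        "log_stream_prefix": f"{cluster_name}/{name}",
        "log_group": "dask-ecs",  # deliberately shared by all Dask ECS clusters
    }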

dask_cloudprovider/aws/tests/test_ecs.py

Lines changed: 127 additions & 0 deletions
@@ -1,3 +1,6 @@
+from unittest import mock
+from unittest.mock import AsyncMock
+
 import pytest

 aiobotocore = pytest.importorskip("aiobotocore")
@@ -6,3 +9,127 @@
 def test_import():
     from dask_cloudprovider.aws import ECSCluster  # noqa
     from dask_cloudprovider.aws import FargateCluster  # noqa
+
+
+def test_reuse_ecs_cluster():
+    from dask_cloudprovider.aws import ECSCluster  # noqa
+
+    fc1_name = "Spooky"
+    fc2_name = "Weevil"
+    vpc_name = "MyNetwork"
+    vpc_subnets = ["MySubnet1", "MySubnet2"]
+    cluster_arn = "CompletelyMadeUp"
+    cluster_name = "Crunchy"
+    log_group_name = "dask-ecs"
+
+    expected_execution_role_name1 = f"dask-{fc1_name}-execution-role"
+    expected_task_role_name1 = f"dask-{fc1_name}-task-role"
+    expected_log_stream_prefix1 = f"{cluster_name}/{fc1_name}"
+    expected_security_group_name1 = f"dask-{fc1_name}-security-group"
+    expected_scheduler_task_name1 = f"dask-{fc1_name}-scheduler"
+    expected_worker_task_name1 = f"dask-{fc1_name}-worker"
+
+    expected_execution_role_name2 = f"dask-{fc2_name}-execution-role"
+    expected_task_role_name2 = f"dask-{fc2_name}-task-role"
+    expected_log_stream_prefix2 = f"{cluster_name}/{fc2_name}"
+    expected_security_group_name2 = f"dask-{fc2_name}-security-group"
+    expected_scheduler_task_name2 = f"dask-{fc2_name}-scheduler"
+    expected_worker_task_name2 = f"dask-{fc2_name}-worker"
+
+    mock_client = AsyncMock()
+    mock_client.describe_clusters.return_value = {
+        "clusters": [{"clusterName": cluster_name}]
+    }
+    mock_client.list_account_settings.return_value = {"settings": {"value": "enabled"}}
+    mock_client.create_role.return_value = {"Role": {"Arn": "Random"}}
+    mock_client.describe_log_groups.return_value = {"logGroups": []}
+
+    class MockSession:
+        class MockClient:
+            async def __aenter__(self, *args, **kwargs):
+                return mock_client
+
+            async def __aexit__(self, *args, **kwargs):
+                return
+
+        def create_client(self, *args, **kwargs):
+            return MockSession.MockClient()
+
+    with (
+        mock.patch(
+            "dask_cloudprovider.aws.ecs.get_session", return_value=MockSession()
+        ),
+        mock.patch("distributed.deploy.spec.SpecCluster._start"),
+        mock.patch("weakref.finalize"),
+    ):
+        # Make ourselves a test cluster
+        fc1 = ECSCluster(
+            name=fc1_name,
+            cluster_arn=cluster_arn,
+            vpc=vpc_name,
+            subnets=vpc_subnets,
+            skip_cleanup=True,
+        )
+        # Are we re-using the existing ECS cluster?
+        assert fc1.cluster_name == cluster_name
+        # Have we made completely unique AWS resources to run on that cluster?
+        assert fc1._execution_role_name == expected_execution_role_name1
+        assert fc1._task_role_name == expected_task_role_name1
+        assert fc1._cloudwatch_logs_stream_prefix == expected_log_stream_prefix1
+        assert (
+            fc1.scheduler_spec["options"]["log_stream_prefix"]
+            == expected_log_stream_prefix1
+        )
+        assert (
+            fc1.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix1
+        )
+        assert fc1.cloudwatch_logs_group == log_group_name
+        assert fc1.scheduler_spec["options"]["log_group"] == log_group_name
+        assert fc1.new_spec["options"]["log_group"] == log_group_name
+        sg_calls = mock_client.create_security_group.call_args_list
+        assert len(sg_calls) == 1
+        assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name1
+        td_calls = mock_client.register_task_definition.call_args_list
+        assert len(td_calls) == 2
+        assert td_calls[0].kwargs["family"] == expected_scheduler_task_name1
+        assert td_calls[1].kwargs["family"] == expected_worker_task_name1
+
+        # Reset mocks ready for second cluster
+        mock_client.create_security_group.reset_mock()
+        mock_client.register_task_definition.reset_mock()
+
+        # Make ourselves a second test cluster on the same ECS cluster
+        fc2 = ECSCluster(
+            name=fc2_name,
+            cluster_arn=cluster_arn,
+            vpc=vpc_name,
+            subnets=vpc_subnets,
+            skip_cleanup=True,
+        )
+        # Are we re-using the existing ECS cluster?
+        assert fc2.cluster_name == cluster_name
+        # Have we made completely unique AWS resources to run on that cluster?
+        assert fc2._execution_role_name == expected_execution_role_name2
+        assert fc2._task_role_name == expected_task_role_name2
+        assert fc2._cloudwatch_logs_stream_prefix == expected_log_stream_prefix2
+        assert (
+            fc2.scheduler_spec["options"]["log_stream_prefix"]
+            == expected_log_stream_prefix2
+        )
+        assert (
+            fc2.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix2
+        )
+        assert fc2.cloudwatch_logs_group == log_group_name
+        assert fc2.scheduler_spec["options"]["log_group"] == log_group_name
+        assert fc2.new_spec["options"]["log_group"] == log_group_name
+        sg_calls = mock_client.create_security_group.call_args_list
+        assert len(sg_calls) == 1
+        assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name2
+        td_calls = mock_client.register_task_definition.call_args_list
+        assert len(td_calls) == 2
+        assert td_calls[0].kwargs["family"] == expected_scheduler_task_name2
+        assert td_calls[1].kwargs["family"] == expected_worker_task_name2
+
+        # Finish up
+        fc1.close()
+        fc2.close()
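
The nested MockClient stub exists because ``ECSCluster._client(...)`` is used as an async context manager; patching ``get_session`` is enough to route every AWS call into the single AsyncMock, which then records the call kwargs the assertions inspect. A stripped-down sketch of that pattern (names here are illustrative, not part of the library):

    from unittest import mock
    from unittest.mock import AsyncMock

    recorded = AsyncMock()  # every awaited client call lands on this mock

    class FakeSession:
        class FakeClient:
            async def __aenter__(self, *args, **kwargs):
                return recorded

            async def __aexit__(self, *args, **kwargs):
                return None

        def create_client(self, *args, **kwargs):
            return FakeSession.FakeClient()

    # With this patch active, ``async with self._client("ecs") as ecs`` yields ``recorded``,
    # and calls can later be inspected via ``recorded.register_task_definition.call_args_list``.
    session_patch = mock.patch(
        "dask_cloudprovider.aws.ecs.get_session", return_value=FakeSession()
    )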

dask_cloudprovider/cloudprovider.yaml

Lines changed: 2 additions & 2 deletions
@@ -17,15 +17,15 @@ cloudprovider:
     image: "daskdev/dask:latest"  # Docker image to use for non GPU tasks
     cpu_architecture: "X86_64"  # Runtime platform CPU architecture
     gpu_image: "rapidsai/rapidsai:latest"  # Docker image to use for GPU tasks
-    cluster_name_template: "dask-{uuid}"  # Template to use when creating a cluster
+    cluster_name_template: "dask-{name}"  # Template to use when creating a cluster
     cluster_arn: ""  # ARN of existing ECS cluster to use (if not set one will be created)
     execution_role_arn: ""  # Arn of existing execution role to use (if not set one will be created)
     task_role_arn: ""  # Arn of existing task role to use (if not set one will be created)
     task_role_policies: []  # List of policy arns to attach to tasks (e.g S3 read only access)
     # platform_version: "LATEST"  # Fargate platformVersion string like "1.4.0" or "LATEST"

     cloudwatch_logs_group: ""  # Name of existing cloudwatch logs group to use (if not set one will be created)
-    cloudwatch_logs_stream_prefix: "{cluster_name}"  # Stream prefix template
+    cloudwatch_logs_stream_prefix: "{cluster_name}/{name}"  # Stream prefix template
     cloudwatch_logs_default_retention: 30  # Number of days to retain logs (only applied if not using existing group)

     vpc: "default"  # VPC to use for tasks
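
Both of these configuration values are plain ``str.format`` templates. A small sketch of how the cluster code expands them when it creates a new ECS cluster (the UUID-style value is illustrative; ``uuid`` is still accepted in ``cluster_name_template`` for backwards compatibility):

    cluster_name_template = "dask-{name}"
    stream_prefix_template = "{cluster_name}/{name}"

    name = "17b1dab5"  # stands in for self.name, usually a UUID
    ecs_cluster_name = cluster_name_template.format(name=name, uuid=name)
    log_stream_prefix = stream_prefix_template.format(
        cluster_name=ecs_cluster_name, name=name
    )

    assert ecs_cluster_name == "dask-17b1dab5"
    assert log_stream_prefix == "dask-17b1dab5/17b1dab5"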
