Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mrack/data/provisioning-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ aws: # aws provider config
tag:
name: my-tag
value: my-special-image
win-2022-latest: # find latest AMI via AWS SSM parameter
ssm: /aws/service/ami-windows-latest/Windows_Server-2022-English-Full-Base

flavors:
# list of available flavours to ask from provider for vm specs
Expand Down
39 changes: 38 additions & 1 deletion src/mrack/providers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def name(self):
"""Get provider name."""
return self._name

@property
def ssm_client(self):
"""Get SSM client, creating it on first use."""
if self._ssm_client is None:
self._ssm_client = boto3.client("ssm")
return self._ssm_client

async def init(
self,
ssh_key,
Expand Down Expand Up @@ -88,6 +95,8 @@ async def init(
) from c_err

self.amis = []
self._ssm_client = None
self.ssm_resolved = {}
self.ssh_key = ssh_key
self.instance_tags = instance_tags
login_end = datetime.now()
Expand Down Expand Up @@ -117,6 +126,17 @@ def validate_tags_image_def(self, image_def):

return True

def validate_ssm_image_def(self, image_def):
"""Validate that SSM parameter definition for image is correct."""
if not isinstance(image_def, dict) or "ssm" not in image_def:
self.raise_image_def_error(image_def)

ssm_val = image_def.get("ssm")
if not isinstance(ssm_val, str):
self.raise_image_def_error(image_def)

return True

def get_image(self, req):
"""
Get a loaded image.
Expand Down Expand Up @@ -144,6 +164,15 @@ def get_image(self, req):
and tag["Value"] == tag_def["value"]
):
return ami
# by SSM parameter
elif isinstance(image_def, dict) and "ssm" in image_def:
self.validate_ssm_image_def(image_def)
ssm_path = image_def["ssm"]
resolved_id = self.ssm_resolved.get(ssm_path)
if resolved_id:
for ami in self.amis:
if ami.image_id == resolved_id:
return ami
# by AMI ID
elif isinstance(image_def, str):
for ami in self.amis:
Expand All @@ -152,7 +181,7 @@ def get_image(self, req):
else:
raise ValidationError(
f"{log_msg_start} Invalid image "
f"definition. Must be 'tags' definition or AMI ID"
f"definition. Must be 'tag', 'ssm', or AMI ID definition"
)
return None

Expand All @@ -174,6 +203,14 @@ def load_image(self, req):
name = tag_def["name"]
filters.append({"Name": f"tag:{name}", "Values": [tag_def["value"]]})

# by SSM parameter
elif isinstance(image_def, dict) and "ssm" in image_def:
ssm_path = image_def["ssm"]
response = self.ssm_client.get_parameter(Name=ssm_path)
Comment thread
mrizwan93 marked this conversation as resolved.
ami_id = response["Parameter"]["Value"]
self.ssm_resolved[ssm_path] = ami_id
filters.append({"Name": "image-id", "Values": [ami_id]})

# by AMI ID
elif isinstance(image_def, str):
filters.append({"Name": "image-id", "Values": [image_def]})
Expand Down
198 changes: 198 additions & 0 deletions tests/unit/test_aws_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Tests for AWS Provider - SSM parameter image resolution."""

from unittest.mock import MagicMock, patch

import pytest

from mrack.errors import ValidationError
from mrack.providers.aws import AWSProvider


class MockAMI:
"""Mock AWS AMI object."""

def __init__(
self, image_id, name=None, tags=None, creation_date="2024-01-01T00:00:00.000Z"
):
self.image_id = image_id
self.name = name
self.tags = tags
self.creation_date = creation_date


@pytest.fixture
def provider():
"""Create an AWSProvider with mocked boto3 clients."""
with patch("mrack.providers.aws.boto3"):
p = AWSProvider()
p.dsp_name = "AWS"
p.amis = []
p._ssm_client = MagicMock()
p.ssm_resolved = {}
p.ec2 = MagicMock()
return p


class TestValidateSSMImageDef:
"""Test validate_ssm_image_def method."""

def test_valid(self, provider):
image_def = {"ssm": "/aws/service/ami-windows-latest/Windows_Server-2022"}
assert provider.validate_ssm_image_def(image_def) is True

def test_missing_ssm_key(self, provider):
with pytest.raises(ValidationError):
provider.validate_ssm_image_def({"tag": "something"})

def test_ssm_not_string(self, provider):
with pytest.raises(ValidationError):
provider.validate_ssm_image_def({"ssm": 123})

def test_not_dict(self, provider):
with pytest.raises(ValidationError):
provider.validate_ssm_image_def("just-a-string")


class TestGetImageSSM:
"""Test get_image with SSM-based image definitions."""

def test_returns_cached_ami_after_resolution(self, provider):
ami = MockAMI("ami-resolved-123")
provider.amis = [ami]
provider.ssm_resolved = {
"/aws/service/ami-windows-latest/Win2022": "ami-resolved-123"
}

req = {
"name": "win-host",
"image": {"ssm": "/aws/service/ami-windows-latest/Win2022"},
}
result = provider.get_image(req)
assert result is ami

def test_returns_none_before_resolution(self, provider):
provider.amis = []
provider.ssm_resolved = {}

req = {
"name": "win-host",
"image": {"ssm": "/aws/service/ami-windows-latest/Win2022"},
}
result = provider.get_image(req)
assert result is None

def test_returns_none_when_resolved_but_not_cached(self, provider):
provider.amis = [MockAMI("ami-other")]
provider.ssm_resolved = {
"/aws/service/ami-windows-latest/Win2022": "ami-resolved-123"
}

req = {
"name": "win-host",
"image": {"ssm": "/aws/service/ami-windows-latest/Win2022"},
}
result = provider.get_image(req)
assert result is None

def test_tag_lookup_still_works(self, provider):
ami = MockAMI("ami-tag", tags=[{"Key": "env", "Value": "prod"}])
provider.amis = [ami]

req = {
"name": "host",
"image": {"tag": {"name": "env", "value": "prod"}},
}
result = provider.get_image(req)
assert result is ami

def test_ami_id_lookup_still_works(self, provider):
ami = MockAMI("ami-direct-123")
provider.amis = [ami]

req = {"name": "host", "image": "ami-direct-123"}
result = provider.get_image(req)
assert result is ami

def test_invalid_dict_raises(self, provider):
req = {"name": "host", "image": {"unknown": "value"}}
with pytest.raises(ValidationError, match="'tag', 'ssm', or AMI ID"):
provider.get_image(req)

def test_no_image_raises(self, provider):
req = {"name": "host"}
with pytest.raises(ValidationError, match="doesn't have image defined"):
provider.get_image(req)


class TestLoadImageSSM:
"""Test load_image with SSM-based image definitions."""

def test_resolves_ssm_and_loads_ami(self, provider):
ssm_path = "/aws/service/ami-windows-latest/Windows_Server-2022"
provider.ssm_client.get_parameter.return_value = {
"Parameter": {"Value": "ami-win2022-latest"}
}

mock_ami = MockAMI(
"ami-win2022-latest", creation_date="2024-06-01T00:00:00.000Z"
)
provider.ec2.images.filter.return_value = [mock_ami]

req = {"name": "win-host", "image": {"ssm": ssm_path}}
result = provider.load_image(req)

provider.ssm_client.get_parameter.assert_called_once_with(Name=ssm_path)
provider.ec2.images.filter.assert_called_once_with(
Filters=[{"Name": "image-id", "Values": ["ami-win2022-latest"]}]
)
assert result is mock_ami
assert provider.ssm_resolved[ssm_path] == "ami-win2022-latest"
assert mock_ami in provider.amis

def test_no_ami_found_raises(self, provider):
provider.ssm_client.get_parameter.return_value = {
"Parameter": {"Value": "ami-nonexistent"}
}
provider.ec2.images.filter.return_value = []

req = {"name": "win-host", "image": {"ssm": "/aws/service/some-path"}}
with pytest.raises(ValidationError, match="Cannot find image"):
provider.load_image(req)

def test_returns_newest_when_multiple(self, provider):
provider.ssm_client.get_parameter.return_value = {
"Parameter": {"Value": "ami-123"}
}

old_ami = MockAMI("ami-123", creation_date="2024-01-01T00:00:00.000Z")
new_ami = MockAMI("ami-123", creation_date="2024-06-01T00:00:00.000Z")
provider.ec2.images.filter.return_value = [old_ami, new_ami]

req = {"name": "host", "image": {"ssm": "/aws/service/path"}}
result = provider.load_image(req)
assert result is new_ami

def test_tag_load_still_works(self, provider):
mock_ami = MockAMI("ami-tag-1", creation_date="2024-01-01T00:00:00.000Z")
provider.ec2.images.filter.return_value = [mock_ami]

req = {
"name": "host",
"image": {"tag": {"name": "env", "value": "prod"}},
}
result = provider.load_image(req)
provider.ec2.images.filter.assert_called_once_with(
Filters=[{"Name": "tag:env", "Values": ["prod"]}]
)
assert result is mock_ami

def test_ami_id_load_still_works(self, provider):
mock_ami = MockAMI("ami-direct", creation_date="2024-01-01T00:00:00.000Z")
provider.ec2.images.filter.return_value = [mock_ami]

req = {"name": "host", "image": "ami-direct"}
result = provider.load_image(req)
provider.ec2.images.filter.assert_called_once_with(
Filters=[{"Name": "image-id", "Values": ["ami-direct"]}]
)
assert result is mock_ami
38 changes: 38 additions & 0 deletions tests/unit/test_aws_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,44 @@ def _aws_provisioning_config(user_data=None):
)


class TestAWSTransformerSSMImage:
"""Test that SSM image definitions pass through the transformer correctly."""

@pytest.mark.asyncio
async def test_ssm_image_def_in_requirement(self):
"""SSM image dict should be passed through as-is into the requirement."""
providers.register(AWS, AWSProvider)
transformer = MockedAWSTransformer()

ssm_path = (
"/aws/service/ami-windows-latest/Windows_Server-2022-English-Full-Base"
)
ssm_image = {"ssm": ssm_path}
aws_cfg = {
"images": {
"win-2022": ssm_image,
"rhel-8.5": "ami-rhel-8-5",
},
"flavors": {"default": "t2.nano"},
"keypair": "mrack-keypair.pem",
"security_group": "sg-something",
"security_groups": ["sg-something"],
"credentials_file": "aws.key",
"profile": "default",
"spot": True,
"instance_tags": {"Name": "mrack-runner"},
}
config = ProvisioningConfig(
{"aws": aws_cfg, "users": {"win-2022": "Administrator"}}
)
hosts = [_host("ad", "ad", "ad", "win-2022")]
metadata = {"domains": [{"name": DOMAIN, "type": "mixed", "hosts": hosts}]}
await transformer.init(config, metadata)

req = transformer.create_host_requirement(hosts[0])
assert req["image"] == ssm_image


class TestAWSTransformerUserData:
"""Test the AWS Transformer's user_data handling."""

Expand Down
Loading