diff --git a/src/mrack/data/provisioning-config.yaml b/src/mrack/data/provisioning-config.yaml index 575a1927..b9d5827c 100644 --- a/src/mrack/data/provisioning-config.yaml +++ b/src/mrack/data/provisioning-config.yaml @@ -60,6 +60,8 @@ aws: # aws provider config tag: name: my-tag value: my-special-image + win-2022-latest: # find latest AMI via AWS SSM parameter + ssm: /aws/service/ami-windows-latest/Windows_Server-2022-English-Full-Base flavors: # list of available flavours to ask from provider for vm specs diff --git a/src/mrack/providers/aws.py b/src/mrack/providers/aws.py index 1f883bbf..ac10e6a6 100644 --- a/src/mrack/providers/aws.py +++ b/src/mrack/providers/aws.py @@ -61,6 +61,13 @@ def name(self): """Get provider name.""" return self._name + @property + def ssm_client(self): + """Get SSM client, creating it on first use.""" + if self._ssm_client is None: + self._ssm_client = boto3.client("ssm") + return self._ssm_client + async def init( self, ssh_key, @@ -88,6 +95,8 @@ async def init( ) from c_err self.amis = [] + self._ssm_client = None + self.ssm_resolved = {} self.ssh_key = ssh_key self.instance_tags = instance_tags login_end = datetime.now() @@ -117,6 +126,17 @@ def validate_tags_image_def(self, image_def): return True + def validate_ssm_image_def(self, image_def): + """Validate that SSM parameter definition for image is correct.""" + if not isinstance(image_def, dict) or "ssm" not in image_def: + self.raise_image_def_error(image_def) + + ssm_val = image_def.get("ssm") + if not isinstance(ssm_val, str): + self.raise_image_def_error(image_def) + + return True + def get_image(self, req): """ Get a loaded image. @@ -144,6 +164,15 @@ def get_image(self, req): and tag["Value"] == tag_def["value"] ): return ami + # by SSM parameter + elif isinstance(image_def, dict) and "ssm" in image_def: + self.validate_ssm_image_def(image_def) + ssm_path = image_def["ssm"] + resolved_id = self.ssm_resolved.get(ssm_path) + if resolved_id: + for ami in self.amis: + if ami.image_id == resolved_id: + return ami # by AMI ID elif isinstance(image_def, str): for ami in self.amis: @@ -152,7 +181,7 @@ def get_image(self, req): else: raise ValidationError( f"{log_msg_start} Invalid image " - f"definition. Must be 'tags' definition or AMI ID" + f"definition. Must be 'tag', 'ssm', or AMI ID definition" ) return None @@ -174,6 +203,14 @@ def load_image(self, req): name = tag_def["name"] filters.append({"Name": f"tag:{name}", "Values": [tag_def["value"]]}) + # by SSM parameter + elif isinstance(image_def, dict) and "ssm" in image_def: + ssm_path = image_def["ssm"] + response = self.ssm_client.get_parameter(Name=ssm_path) + ami_id = response["Parameter"]["Value"] + self.ssm_resolved[ssm_path] = ami_id + filters.append({"Name": "image-id", "Values": [ami_id]}) + # by AMI ID elif isinstance(image_def, str): filters.append({"Name": "image-id", "Values": [image_def]}) diff --git a/tests/unit/test_aws_provider.py b/tests/unit/test_aws_provider.py new file mode 100644 index 00000000..9c7568c7 --- /dev/null +++ b/tests/unit/test_aws_provider.py @@ -0,0 +1,198 @@ +"""Tests for AWS Provider - SSM parameter image resolution.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from mrack.errors import ValidationError +from mrack.providers.aws import AWSProvider + + +class MockAMI: + """Mock AWS AMI object.""" + + def __init__( + self, image_id, name=None, tags=None, creation_date="2024-01-01T00:00:00.000Z" + ): + self.image_id = image_id + self.name = name + self.tags = tags + self.creation_date = creation_date + + +@pytest.fixture +def provider(): + """Create an AWSProvider with mocked boto3 clients.""" + with patch("mrack.providers.aws.boto3"): + p = AWSProvider() + p.dsp_name = "AWS" + p.amis = [] + p._ssm_client = MagicMock() + p.ssm_resolved = {} + p.ec2 = MagicMock() + return p + + +class TestValidateSSMImageDef: + """Test validate_ssm_image_def method.""" + + def test_valid(self, provider): + image_def = {"ssm": "/aws/service/ami-windows-latest/Windows_Server-2022"} + assert provider.validate_ssm_image_def(image_def) is True + + def test_missing_ssm_key(self, provider): + with pytest.raises(ValidationError): + provider.validate_ssm_image_def({"tag": "something"}) + + def test_ssm_not_string(self, provider): + with pytest.raises(ValidationError): + provider.validate_ssm_image_def({"ssm": 123}) + + def test_not_dict(self, provider): + with pytest.raises(ValidationError): + provider.validate_ssm_image_def("just-a-string") + + +class TestGetImageSSM: + """Test get_image with SSM-based image definitions.""" + + def test_returns_cached_ami_after_resolution(self, provider): + ami = MockAMI("ami-resolved-123") + provider.amis = [ami] + provider.ssm_resolved = { + "/aws/service/ami-windows-latest/Win2022": "ami-resolved-123" + } + + req = { + "name": "win-host", + "image": {"ssm": "/aws/service/ami-windows-latest/Win2022"}, + } + result = provider.get_image(req) + assert result is ami + + def test_returns_none_before_resolution(self, provider): + provider.amis = [] + provider.ssm_resolved = {} + + req = { + "name": "win-host", + "image": {"ssm": "/aws/service/ami-windows-latest/Win2022"}, + } + result = provider.get_image(req) + assert result is None + + def test_returns_none_when_resolved_but_not_cached(self, provider): + provider.amis = [MockAMI("ami-other")] + provider.ssm_resolved = { + "/aws/service/ami-windows-latest/Win2022": "ami-resolved-123" + } + + req = { + "name": "win-host", + "image": {"ssm": "/aws/service/ami-windows-latest/Win2022"}, + } + result = provider.get_image(req) + assert result is None + + def test_tag_lookup_still_works(self, provider): + ami = MockAMI("ami-tag", tags=[{"Key": "env", "Value": "prod"}]) + provider.amis = [ami] + + req = { + "name": "host", + "image": {"tag": {"name": "env", "value": "prod"}}, + } + result = provider.get_image(req) + assert result is ami + + def test_ami_id_lookup_still_works(self, provider): + ami = MockAMI("ami-direct-123") + provider.amis = [ami] + + req = {"name": "host", "image": "ami-direct-123"} + result = provider.get_image(req) + assert result is ami + + def test_invalid_dict_raises(self, provider): + req = {"name": "host", "image": {"unknown": "value"}} + with pytest.raises(ValidationError, match="'tag', 'ssm', or AMI ID"): + provider.get_image(req) + + def test_no_image_raises(self, provider): + req = {"name": "host"} + with pytest.raises(ValidationError, match="doesn't have image defined"): + provider.get_image(req) + + +class TestLoadImageSSM: + """Test load_image with SSM-based image definitions.""" + + def test_resolves_ssm_and_loads_ami(self, provider): + ssm_path = "/aws/service/ami-windows-latest/Windows_Server-2022" + provider.ssm_client.get_parameter.return_value = { + "Parameter": {"Value": "ami-win2022-latest"} + } + + mock_ami = MockAMI( + "ami-win2022-latest", creation_date="2024-06-01T00:00:00.000Z" + ) + provider.ec2.images.filter.return_value = [mock_ami] + + req = {"name": "win-host", "image": {"ssm": ssm_path}} + result = provider.load_image(req) + + provider.ssm_client.get_parameter.assert_called_once_with(Name=ssm_path) + provider.ec2.images.filter.assert_called_once_with( + Filters=[{"Name": "image-id", "Values": ["ami-win2022-latest"]}] + ) + assert result is mock_ami + assert provider.ssm_resolved[ssm_path] == "ami-win2022-latest" + assert mock_ami in provider.amis + + def test_no_ami_found_raises(self, provider): + provider.ssm_client.get_parameter.return_value = { + "Parameter": {"Value": "ami-nonexistent"} + } + provider.ec2.images.filter.return_value = [] + + req = {"name": "win-host", "image": {"ssm": "/aws/service/some-path"}} + with pytest.raises(ValidationError, match="Cannot find image"): + provider.load_image(req) + + def test_returns_newest_when_multiple(self, provider): + provider.ssm_client.get_parameter.return_value = { + "Parameter": {"Value": "ami-123"} + } + + old_ami = MockAMI("ami-123", creation_date="2024-01-01T00:00:00.000Z") + new_ami = MockAMI("ami-123", creation_date="2024-06-01T00:00:00.000Z") + provider.ec2.images.filter.return_value = [old_ami, new_ami] + + req = {"name": "host", "image": {"ssm": "/aws/service/path"}} + result = provider.load_image(req) + assert result is new_ami + + def test_tag_load_still_works(self, provider): + mock_ami = MockAMI("ami-tag-1", creation_date="2024-01-01T00:00:00.000Z") + provider.ec2.images.filter.return_value = [mock_ami] + + req = { + "name": "host", + "image": {"tag": {"name": "env", "value": "prod"}}, + } + result = provider.load_image(req) + provider.ec2.images.filter.assert_called_once_with( + Filters=[{"Name": "tag:env", "Values": ["prod"]}] + ) + assert result is mock_ami + + def test_ami_id_load_still_works(self, provider): + mock_ami = MockAMI("ami-direct", creation_date="2024-01-01T00:00:00.000Z") + provider.ec2.images.filter.return_value = [mock_ami] + + req = {"name": "host", "image": "ami-direct"} + result = provider.load_image(req) + provider.ec2.images.filter.assert_called_once_with( + Filters=[{"Name": "image-id", "Values": ["ami-direct"]}] + ) + assert result is mock_ami diff --git a/tests/unit/test_aws_transformer.py b/tests/unit/test_aws_transformer.py index 9cd9291b..0a539914 100644 --- a/tests/unit/test_aws_transformer.py +++ b/tests/unit/test_aws_transformer.py @@ -76,6 +76,44 @@ def _aws_provisioning_config(user_data=None): ) +class TestAWSTransformerSSMImage: + """Test that SSM image definitions pass through the transformer correctly.""" + + @pytest.mark.asyncio + async def test_ssm_image_def_in_requirement(self): + """SSM image dict should be passed through as-is into the requirement.""" + providers.register(AWS, AWSProvider) + transformer = MockedAWSTransformer() + + ssm_path = ( + "/aws/service/ami-windows-latest/Windows_Server-2022-English-Full-Base" + ) + ssm_image = {"ssm": ssm_path} + aws_cfg = { + "images": { + "win-2022": ssm_image, + "rhel-8.5": "ami-rhel-8-5", + }, + "flavors": {"default": "t2.nano"}, + "keypair": "mrack-keypair.pem", + "security_group": "sg-something", + "security_groups": ["sg-something"], + "credentials_file": "aws.key", + "profile": "default", + "spot": True, + "instance_tags": {"Name": "mrack-runner"}, + } + config = ProvisioningConfig( + {"aws": aws_cfg, "users": {"win-2022": "Administrator"}} + ) + hosts = [_host("ad", "ad", "ad", "win-2022")] + metadata = {"domains": [{"name": DOMAIN, "type": "mixed", "hosts": hosts}]} + await transformer.init(config, metadata) + + req = transformer.create_host_requirement(hosts[0]) + assert req["image"] == ssm_image + + class TestAWSTransformerUserData: """Test the AWS Transformer's user_data handling."""