Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 97 additions & 18 deletions src/snakemake_software_deployment_plugin_container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
__copyright__ = "Copyright 2025, ben carrillo"
__email__ = "[email protected]"
__license__ = "MIT"
import json
import os.path
import subprocess
import tempfile

from dataclasses import dataclass, field
Expand All @@ -29,6 +31,9 @@
# The mountpoint for the Snakemake working directory inside the container.
SNAKEMAKE_MOUNTPOINT = "/mnt/snakemake"

# Where the source-cache dir is found under the cache folder
SOURCE_CACHE = "snakemake/source-cache"


# ContainerType is an enum that defines the different container types we support.
# If adding new ones, make sure the choice is the same as the command name.
Expand Down Expand Up @@ -76,6 +81,14 @@ class ContainerEnv(EnvBase):
def __post_init__(self) -> None:
    # Run self.check() as soon as the instance has been initialized.
    # NOTE(review): check() is not defined in the visible part of this file;
    # presumably it is inherited from EnvBase -- confirm.
    self.check()

def _get_image_uri_and_tag(self) -> Iterable[str]:
    """Split ``self.spec.image_uri`` into ``[uri, tag]``.

    The tag is whatever follows the last ":" of the reference, unless that
    part contains a "/" -- in which case the colon belongs to a registry
    ``host:port`` prefix and the reference carries no tag at all.
    A reference without a tag gets the tag ``"latest"``.

    Returns:
        A two-element list ``[uri, tag]``; ``":".join(...)`` of the result
        is the fully tagged image reference.

    Raises:
        WorkflowError: if the reference has an empty name or tag, or if
            more than one colon appears in the final path component.
    """
    image_uri = self.spec.image_uri
    name, sep, tag = image_uri.rpartition(":")

    # No colon at all, or the last colon is part of "host:port/...":
    # the reference is untagged.
    if not sep or "/" in tag:
        return [image_uri, "latest"]

    # A colon left in the last path component of the name means the
    # reference had more than one tag separator (e.g. "a:b:c").
    last_component = name.rpartition("/")[2]
    if not name or not tag or ":" in last_component:
        raise WorkflowError("Malformed image URI", image_uri)

    return [name, tag]

# The decorator ensures that the decorated method is only called once
# in case multiple environments of the same kind are created.
@EnvBase.once
Expand All @@ -86,8 +99,6 @@ def _check_service(self) -> bool:
if self.spec.image_uri == "":
raise WorkflowError("Image URI is empty")

# TODO: if we don't get the tag, we should assume :latest

if self.settings.kind not in ContainerType.all():
raise WorkflowError("Invalid container kind")

Expand All @@ -103,11 +114,10 @@ def _check_executable(self):

def decorate_shellcmd(self, cmd: str) -> str:
# TODO pass more options here (extra mount volumes, user etc)
image = ":".join(self._get_image_uri_and_tag())

hostcache = os.path.join(get_appdirs().user_cache_dir, "snakemake/source-cache")
containercache = os.path.join(
SNAKEMAKE_MOUNTPOINT, ".cache/snakemake/source-cache"
)
hostcache = os.path.join(get_appdirs().user_cache_dir, SOURCE_CACHE)
containercache = os.path.join(SNAKEMAKE_MOUNTPOINT, ".cache", SOURCE_CACHE)

if not os.path.exists(hostcache):
hostcache = containercache = tempfile.mkdtemp()
Expand All @@ -130,7 +140,7 @@ def decorate_shellcmd(self, cmd: str) -> str:
hostdir=repr(getcwd()), # TODO: allow to override
hostcache=repr(hostcache),
containercache=repr(containercache),
image_id=self.spec.image_uri,
image_id=image,
shell="/bin/sh",
cmd=cmd.replace("'", r"'\''"),
)
Expand All @@ -145,14 +155,83 @@ def record_hash(self, hash_object) -> None:
hash_object.update(...)

def report_software(self) -> Iterable[SoftwareReport]:
    """Report the container image as the software of this environment.

    Yields a single ``SoftwareReport`` whose name is the image URI and whose
    version is the image tag. When the container backend can resolve the
    image locally (i.e. it has already been fetched), the version is
    extended to ``"<tag>/<short image id>"``.
    """
    uri, tag = self._get_image_uri_and_tag()
    image = SoftwareReport(
        name=uri,
        version=tag,
    )

    # In addition to the image tag, we also want to include the image id in
    # the reported version.
    # TODO: can move the managers to the initialization to encapsulate
    # backend-specific logic
    # TODO: we can retrieve the dereferenced URI from the image repo. But
    # different backends have different ways of representing the metadata.
    if self.settings.kind == ContainerType.PODMAN:
        manager = PodmanManager()
    elif self.settings.kind == ContainerType.UDOCKER:
        manager = UDockerManager()
    else:
        # Unknown backend: report the tag only instead of failing on an
        # unbound manager.
        manager = None

    if manager is not None:
        # Inspect the exact reference that gets executed (uri *and* tag);
        # inspecting the bare uri would resolve to the default tag.
        full_image_id = manager.inspect_image(f"{uri}:{tag}")
        if full_image_id != "":
            image.version = f"{image.version}/{full_image_id}"

    yield image


class UDockerManager:
    """Query image metadata through the udocker command line."""

    # Name of the executable, taken from the container type choice.
    cmd = ContainerType.UDOCKER.item_to_choice()

    def inspect_image(self, image_id) -> str:
        """Return a short hash for ``image_id`` as reported by udocker.

        Runs ``<cmd> inspect <image_id>``, parses its stdout as JSON and
        takes the first entry of ``rootfs.diff_ids``, stripped of a leading
        ``sha256:`` and truncated to 12 characters.
        NOTE(review): this is the first ``diff_ids`` entry, not an image id
        field -- confirm this is the intended identifier.

        This is best effort: an empty string is returned when the hash is
        absent, and also (after printing an error) when the command cannot
        be run, fails, or prints something that cannot be parsed.
        """
        try:
            # Run udocker inspect command
            result = subprocess.run(
                [self.cmd, "inspect", image_id],
                capture_output=True,
                text=True,
                check=True,
            )

            # Parse the output as JSON
            inspect_data = json.loads(result.stdout)

            # Extract the hash from rootfs.diff_ids
            if "rootfs" not in inspect_data:
                return ""
            rootfs = inspect_data["rootfs"]
            if "diff_ids" not in rootfs or len(rootfs["diff_ids"]) == 0:
                return ""  # Return empty string if hash not found

            diff_id = rootfs["diff_ids"][0]
            # Remove sha256: prefix if present
            if diff_id.startswith("sha256:"):
                return diff_id[7:19]  # First 12 chars after prefix
            return diff_id[:12]

        except (
            OSError,  # executable missing or not runnable
            subprocess.CalledProcessError,
            json.JSONDecodeError,
            KeyError,
            IndexError,
            TypeError,  # JSON document of an unexpected shape
        ) as e:
            print(f"error: failed to extract hash for udocker image {image_id}: {e}")
            return ""


class PodmanManager:
    """Query image metadata through the podman command line."""

    # Name of the executable, taken from the container type choice.
    cmd = ContainerType.PODMAN.item_to_choice()

    def inspect_image(self, image_id) -> str:
        """Return the first 12 characters of the image id known to podman.

        Runs ``<cmd> inspect <image_id>``, parses its stdout as JSON and
        reads the ``Id`` field of the first entry.

        This is best effort: an empty string is returned (after printing an
        error) when the command cannot be run, fails, or prints something
        that cannot be parsed.
        """
        try:
            result = subprocess.run(
                [self.cmd, "inspect", image_id],
                capture_output=True,
                text=True,
                check=True,
            )
            inspect_data = json.loads(result.stdout)
            full_image_id = inspect_data[0]["Id"]
            return full_image_id[:12]
        except (OSError, subprocess.CalledProcessError) as e:
            # OSError covers an executable that is missing or not runnable.
            print(f"error: failed to inspect image {image_id}: {e}")
            return ""
        except (KeyError, IndexError, TypeError, json.JSONDecodeError) as e:
            # TypeError covers JSON of an unexpected shape (e.g. a null Id).
            print(f"error: failed to parse output for image {image_id}: {e}")
            return ""
43 changes: 36 additions & 7 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import subprocess as sp
from typing import Optional, Type

import pytest
Expand Down Expand Up @@ -53,6 +54,24 @@ def get_test_cmd(self) -> str:
# with exit code 0 (i.e. without error).
return "/bin/true"

def test_report_software(self, tmp_path):
    """The reported software is the container image, as "<tag>/<short id>"."""
    env = self._get_env(tmp_path)
    cmd = self.get_test_cmd()
    decorated_cmd = env.managed_decorate_shellcmd(cmd)

    # force the run to actually fetch the image; check=True makes a failed
    # pull/run fail right here instead of as an obscure assertion below.
    # TODO: there might be a better way to test this after the automatic
    # testing has actually been called
    sp.run(
        decorated_cmd,
        shell=True,
        executable=self.shell_executable,
        check=True,
    )
    rep = tuple(env.report_software())

    # check the first software reported, should be the container
    # We're reporting version as the tag + the hash of the image
    # latest/aded1e1a5b37
    expected_prefix = "latest/"
    short_id_length = 12
    assert rep[0].name == "alpine"
    assert len(rep[0].version) == len(expected_prefix) + short_id_length
    assert rep[0].version.startswith(expected_prefix)


# Helper function to check if podman is available
def is_podman_available():
Expand Down Expand Up @@ -88,10 +107,20 @@ def get_test_cmd(self) -> str:
# with exit code 0 (i.e. without error).
return "/bin/true"


# Test that the container is outputting something useful at all
"""
sp.run(
decorated_cmd, shell=True, executable=self.shell_executable
).returncode
"""
# This test is optional; we are interested in peeking beyond the interface
# and make sure we're getting specific information from the container.
def test_report_software(self, tmp_path):
    """The reported software is the container image, as "<tag>/<short id>"."""
    env = self._get_env(tmp_path)
    cmd = self.get_test_cmd()
    decorated_cmd = env.managed_decorate_shellcmd(cmd)

    # force the run to actually fetch the image; check=True makes a failed
    # pull/run fail right here instead of as an obscure assertion below.
    sp.run(
        decorated_cmd,
        shell=True,
        executable=self.shell_executable,
        check=True,
    )
    rep = tuple(env.report_software())

    # check the first software reported, should be the container
    # We're reporting version as the tag + the hash of the image
    # latest/aded1e1a5b37
    expected_prefix = "latest/"
    short_id_length = 12
    assert rep[0].name == "alpine"
    assert len(rep[0].version) == len(expected_prefix) + short_id_length
    assert rep[0].version.startswith(expected_prefix)