Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion agent/skyhook-agent/src/skyhook_agent/chroot_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,45 @@
import shutil


def _get_process_env(container_env: dict, skyhook_env: dict, chroot_env: dict):
# Set this first with the container environment.
# We need to do this because the skyhook package could set any env var they want so that needs
# to get replicated down. BUT we are in distroless so we then need to overwrite with the chroot environment
# so things like path/user resolution work.
process_env = dict(container_env)
# Overwrite the container environment with the chroot environment
process_env.update(chroot_env)
# Inject the skyhook environment variables
process_env.update(skyhook_env)
return process_env

def _get_chroot_env():
results = subprocess.run(["env"], capture_output=True, text=True)
env = {}
for line in results.stdout.split("\n"):
if "=" in line:
k, v = line.split("=", 1)
env[k] = v
return env

def chroot_exec(config: dict, chroot_dir: str):
cmds = config["cmd"]
no_chmod = config["no_chmod"]
skyhook_env = config["env"]

# Capture container environment before chroot
container_env = dict(os.environ)

if chroot_dir != "local":
os.chroot(chroot_dir)
os.chdir("/")
try:
if not no_chmod:
# chmod +x the step
os.chmod(cmds[0], os.stat(cmds[0]).st_mode | stat.S_IXGRP | stat.S_IXUSR | stat.S_IXOTH)
subprocess.run(cmds, check=True)

process_env = _get_process_env(container_env, skyhook_env, _get_chroot_env())
subprocess.run(cmds, check=True, env=process_env)
except:
raise

Expand Down
12 changes: 7 additions & 5 deletions agent/skyhook-agent/src/skyhook_agent/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def _stream_process(
break


async def tee(chroot_dir: str, cmd: List[str], stdout_sink_path: str, stderr_sink_path: str, write_cmds=False, no_chmod=False, **kwargs):
async def tee(chroot_dir: str, cmd: List[str], stdout_sink_path: str, stderr_sink_path: str, write_cmds=False, no_chmod=False, env: dict[str, str] = {}, **kwargs):
"""
Run the cmd in a subprocess and keep the stream of stdout/stderr and merge both into
the sink_path as a log.
Expand All @@ -142,7 +142,7 @@ async def tee(chroot_dir: str, cmd: List[str], stdout_sink_path: str, stderr_sin
sys.stdout.write(" ".join(cmd) + "\n")
stdout_sink_f.write(" ".join(cmd) + "\n")
with tempfile.NamedTemporaryFile(mode="w", delete=True) as f:
f.write(json.dumps({"cmd": cmd, "no_chmod": no_chmod}))
f.write(json.dumps({"cmd": cmd, "no_chmod": no_chmod, "env": env}))
f.flush()

# Run the special chroot_exec.py script to chroot into the directory and run the command
Expand Down Expand Up @@ -220,7 +220,7 @@ def set_flag(flag_file: str, msg: str = "") -> None:
f.write(msg)


def _run(chroot_dir: str, cmds: list[str], log_path: str, write_cmds=False, no_chmod=False,**kwargs) -> int:
def _run(chroot_dir: str, cmds: list[str], log_path: str, write_cmds=False, no_chmod=False, env: dict[str, str] = {}, **kwargs) -> int:
"""
Synchronous wrapper around the tee command to have logs written to disk
"""
Expand All @@ -234,6 +234,7 @@ def _run(chroot_dir: str, cmds: list[str], log_path: str, write_cmds=False, no_c
f"{log_path}.err",
write_cmds=write_cmds,
no_chmod=no_chmod,
env=env,
**kwargs
)
)
Expand Down Expand Up @@ -284,10 +285,11 @@ def run_step(
time.sleep(1)
log_file = get_log_file(step_path, copy_dir, config_data, chroot_dir)

# Make sure to include the original environment here or else things like path resolution dont work
env = dict(**os.environ)
# Compile additional environment variables
env = {}
env.update(step.env)
env.update({"STEP_ROOT": get_host_path_for_steps(copy_dir), "SKYHOOK_DIR": copy_dir})

return_code = _run(
chroot_dir,
[step_path, *step.arguments],
Expand Down
233 changes: 233 additions & 0 deletions agent/skyhook-agent/tests/test_chroot_exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest import mock
from skyhook_agent.chroot_exec import _get_process_env, _get_chroot_env


class TestChrootExec(unittest.TestCase):

def test_get_process_env_basic_functionality(self):
"""Test _get_process_env with non-overlapping keys"""
container_env = {"CONTAINER_VAR": "container_value"}
chroot_env = {"CHROOT_VAR": "chroot_value"}
skyhook_env = {"SKYHOOK_VAR": "skyhook_value"}

result = _get_process_env(container_env, skyhook_env, chroot_env)

expected = {
"CONTAINER_VAR": "container_value",
"CHROOT_VAR": "chroot_value",
"SKYHOOK_VAR": "skyhook_value"
}
self.assertEqual(result, expected)

def test_get_process_env_chroot_overrides_container(self):
"""Test that chroot_env overrides container_env for same keys"""
container_env = {"SAME_VAR": "container_value", "CONTAINER_VAR": "container_value"}
chroot_env = {"SAME_VAR": "chroot_value", "CHROOT_VAR": "chroot_value"}
skyhook_env = {"SKYHOOK_VAR": "skyhook_value"}

result = _get_process_env(container_env, skyhook_env, chroot_env)

expected = {
"SAME_VAR": "chroot_value", # chroot overrides container
"CONTAINER_VAR": "container_value",
"CHROOT_VAR": "chroot_value",
"SKYHOOK_VAR": "skyhook_value"
}
self.assertEqual(result, expected)

def test_get_process_env_skyhook_overrides_all(self):
"""Test that skyhook_env has highest priority and overrides both chroot and container"""
container_env = {"SAME_VAR": "container_value", "CONTAINER_VAR": "container_value"}
chroot_env = {"SAME_VAR": "chroot_value", "CHROOT_VAR": "chroot_value"}
skyhook_env = {"SAME_VAR": "skyhook_value", "SKYHOOK_VAR": "skyhook_value"}

result = _get_process_env(container_env, skyhook_env, chroot_env)

expected = {
"SAME_VAR": "skyhook_value", # skyhook overrides both chroot and container
"CONTAINER_VAR": "container_value",
"CHROOT_VAR": "chroot_value",
"SKYHOOK_VAR": "skyhook_value"
}
self.assertEqual(result, expected)

def test_get_process_env_with_empty_dicts(self):
"""Test _get_process_env with empty dictionaries"""
result = _get_process_env({}, {}, {})
self.assertEqual(result, {})

# Test with only one dict having values
container_env = {"VAR": "value"}
result = _get_process_env(container_env, {}, {})
self.assertEqual(result, {"VAR": "value"})

chroot_env = {"VAR": "value"}
result = _get_process_env({}, {}, chroot_env)
self.assertEqual(result, {"VAR": "value"})

skyhook_env = {"VAR": "value"}
result = _get_process_env({}, skyhook_env, {})
self.assertEqual(result, {"VAR": "value"})

def test_get_process_env_precedence_order(self):
"""Test complete precedence order: skyhook > chroot > container"""
container_env = {
"PATH": "/container/path",
"HOME": "/container/home",
"USER": "container_user",
"ONLY_CONTAINER": "container_only"
}
chroot_env = {
"PATH": "/chroot/path",
"HOME": "/chroot/home",
"ONLY_CHROOT": "chroot_only"
}
skyhook_env = {
"PATH": "/skyhook/path",
"ONLY_SKYHOOK": "skyhook_only"
}

result = _get_process_env(container_env, skyhook_env, chroot_env)

expected = {
"PATH": "/skyhook/path", # skyhook wins
"HOME": "/chroot/home", # chroot wins over container
"USER": "container_user", # only in container
"ONLY_CONTAINER": "container_only",
"ONLY_CHROOT": "chroot_only",
"ONLY_SKYHOOK": "skyhook_only"
}
self.assertEqual(result, expected)

def test_get_process_env_does_not_modify_input_dicts(self):
"""Test that input dictionaries are not modified"""
container_env = {"VAR": "container"}
chroot_env = {"VAR": "chroot"}
skyhook_env = {"VAR": "skyhook"}

# Keep original references
original_container = container_env.copy()
original_chroot = chroot_env.copy()
original_skyhook = skyhook_env.copy()

result = _get_process_env(container_env, skyhook_env, chroot_env)

# Verify input dicts weren't modified
self.assertEqual(container_env, original_container)
self.assertEqual(chroot_env, original_chroot)
self.assertEqual(skyhook_env, original_skyhook)

# Verify result is correct
self.assertEqual(result, {"VAR": "skyhook"})

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_basic_functionality(self, mock_subprocess):
"""Test _get_chroot_env with typical environment output"""
mock_result = mock.MagicMock()
mock_result.stdout = "PATH=/usr/bin:/bin\nHOME=/root\nUSER=root\n"
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

expected = {
"PATH": "/usr/bin:/bin",
"HOME": "/root",
"USER": "root"
}
self.assertEqual(result, expected)
mock_subprocess.assert_called_once_with(["env"], capture_output=True, text=True)

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_with_multiple_equals(self, mock_subprocess):
"""Test _get_chroot_env correctly handles lines with multiple '=' characters"""
mock_result = mock.MagicMock()
mock_result.stdout = "VAR1=value=with=equals\nVAR2=simple_value\n"
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

expected = {
"VAR1": "value=with=equals", # Should split only on first =
"VAR2": "simple_value"
}
self.assertEqual(result, expected)

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_ignores_lines_without_equals(self, mock_subprocess):
"""Test _get_chroot_env ignores lines that don't contain '='"""
mock_result = mock.MagicMock()
mock_result.stdout = "PATH=/usr/bin\ninvalid_line_no_equals\nHOME=/root\n\n"
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

expected = {
"PATH": "/usr/bin",
"HOME": "/root"
}
self.assertEqual(result, expected)

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_with_empty_output(self, mock_subprocess):
"""Test _get_chroot_env with empty subprocess output"""
mock_result = mock.MagicMock()
mock_result.stdout = ""
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

self.assertEqual(result, {})
mock_subprocess.assert_called_once_with(["env"], capture_output=True, text=True)

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_with_empty_values(self, mock_subprocess):
"""Test _get_chroot_env handles environment variables with empty values"""
mock_result = mock.MagicMock()
mock_result.stdout = "EMPTY_VAR=\nNORM_VAR=value\nANOTHER_EMPTY=\n"
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

expected = {
"EMPTY_VAR": "",
"NORM_VAR": "value",
"ANOTHER_EMPTY": ""
}
self.assertEqual(result, expected)

@mock.patch('skyhook_agent.chroot_exec.subprocess.run')
def test_get_chroot_env_with_whitespace_and_special_chars(self, mock_subprocess):
"""Test _get_chroot_env handles values with whitespace and special characters"""
mock_result = mock.MagicMock()
mock_result.stdout = "VAR_WITH_SPACES=value with spaces\nSPECIAL_CHARS=!@#$%^&*()\nPATH=/usr/bin:/bin\n"
mock_subprocess.return_value = mock_result

result = _get_chroot_env()

expected = {
"VAR_WITH_SPACES": "value with spaces",
"SPECIAL_CHARS": "!@#$%^&*()",
"PATH": "/usr/bin:/bin"
}
self.assertEqual(result, expected)


if __name__ == '__main__':
unittest.main()
16 changes: 8 additions & 8 deletions agent/skyhook-agent/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def test_run_step_replaces_environment_variables(
["copy_dir/skyhook_dir/foo", "a", "foo"],
log_file,
f"{log_file}.err",
env=dict(**os.environ, **{"STEP_ROOT": "copy_dir/skyhook_dir", "FOO": "foo", "SKYHOOK_DIR": "copy_dir"}),
write_cmds=False,
no_chmod=False
no_chmod=False,
env={"STEP_ROOT": "copy_dir/skyhook_dir", "SKYHOOK_DIR": "copy_dir"}
)
]
)
Expand Down Expand Up @@ -690,9 +690,9 @@ def test_from_and_to_version_is_given_to_upgrade_step_as_env_var(self, run_mock,
controller.get_log_file(
f"{controller.get_host_path_for_steps(copy_dir)}/foo", f"/foo", config_data, root_dir
),
env=dict(**os.environ,
**{"PREVIOUS_VERSION": "0.0.9", "CURRENT_VERSION": "1.0.0"},
**{"STEP_ROOT": f"{root_dir}/{copy_dir}/skyhook_dir", "SKYHOOK_DIR": copy_dir})
env=dict(
**{"PREVIOUS_VERSION": "0.0.9", "CURRENT_VERSION": "1.0.0"},
**{"STEP_ROOT": f"{root_dir}/{copy_dir}/skyhook_dir", "SKYHOOK_DIR": copy_dir})
)
])

Expand Down Expand Up @@ -729,9 +729,9 @@ def test_from_and_to_version_is_given_to_upgradestep_class_as_env_var_and_args(s
controller.get_log_file(
f"{controller.get_host_path_for_steps(copy_dir)}/foo", f"/foo", config_data, root_dir
),
env=dict(**os.environ,
**{"PREVIOUS_VERSION": "2024.07.28", "CURRENT_VERSION": "1.0.0"},
**{"STEP_ROOT": f"{root_dir}/{copy_dir}/skyhook_dir", "SKYHOOK_DIR": copy_dir})
env=dict(
**{"PREVIOUS_VERSION": "2024.07.28", "CURRENT_VERSION": "1.0.0"},
**{"STEP_ROOT": f"{root_dir}/{copy_dir}/skyhook_dir", "SKYHOOK_DIR": copy_dir})
)
])

Expand Down