Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,13 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
}
kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())

step_args = " ".join(util.dict_to_cli_options(kwargs))
# For multinode, create modified kwargs for command construction only
num_parallel = num_parallel or 0
step_kwargs = kwargs.copy()
if num_parallel and num_parallel > 1:
step_kwargs["task_id"] = f"{kwargs['task_id']}[NODE-INDEX]"

step_args = " ".join(util.dict_to_cli_options(step_kwargs))
if num_parallel and num_parallel > 1:
# For multinode, we need to add a placeholder that can be mutated by the caller
step_args += " [multinode-args]"
Expand All @@ -270,15 +275,26 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
retry_deco[0].attributes.get("minutes_between_retries", 1)
)

# Set batch attributes
# Set batch attributes - for multinode, use the modified task_id so that MF_PATHSPEC carries the [NODE-INDEX] placeholder
task_spec_task_id = (
step_kwargs["task_id"] if num_parallel > 1 else kwargs["task_id"]
)
task_spec = {
"flow_name": ctx.obj.flow.name,
"step_name": step_name,
"run_id": kwargs["run_id"],
"task_id": task_spec_task_id,
"retry_count": str(retry_count),
}
# Keep attrs clean with original task_id for metadata
main_task_spec = {
"flow_name": ctx.obj.flow.name,
"step_name": step_name,
"run_id": kwargs["run_id"],
"task_id": kwargs["task_id"],
"retry_count": str(retry_count),
}
attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
attrs = {"metaflow.%s" % k: v for k, v in main_task_spec.items()}
attrs["metaflow.user"] = util.get_username()
attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
"metaflow_version"
Expand Down
20 changes: 8 additions & 12 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,21 @@ def execute(self):
commands = self.payload["containerOverrides"]["command"][-1]
# add split-index as this worker is also an ubf_task
commands = commands.replace("[multinode-args]", "--split-index 0")
# For main node, remove the placeholder since it keeps the original task ID
commands = commands.replace("[NODE-INDEX]", "")
main_task_override["command"][-1] = commands

# secondary tasks
secondary_task_container_override = copy.deepcopy(
self.payload["containerOverrides"]
)
secondary_commands = self.payload["containerOverrides"]["command"][-1]
# other tasks do not have the control- prefix, and have the split id appended to the task-id
secondary_commands = secondary_commands.replace(
self._task_id,
self._task_id.replace("control-", "")
+ "-node-$AWS_BATCH_JOB_NODE_INDEX",
)
secondary_commands = secondary_commands.replace(
"ubf_control",
"ubf_task",
)
secondary_commands = secondary_commands.replace(
"[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX"
# For secondary nodes: remove "control-" prefix and replace placeholders
secondary_commands = (
secondary_commands.replace("control-", "")
.replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
.replace("ubf_control", "ubf_task")
.replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
)

secondary_task_container_override["command"][-1] = secondary_commands
Expand Down