|
| 1 | +from typing import ( |
| 2 | + cast, |
| 3 | + Optional, |
| 4 | +) |
| 5 | + |
| 6 | +from gxformat2.model import ( |
| 7 | + get_native_step_type, |
| 8 | + pop_connect_from_step_dict, |
| 9 | + setup_connected_values, |
| 10 | + steps_as_list, |
| 11 | +) |
| 12 | + |
| 13 | +from galaxy.tool_util.parameters import ( |
| 14 | + ConditionalParameterModel, |
| 15 | + ConditionalWhen, |
| 16 | + flat_state_path, |
| 17 | + keys_starting_with, |
| 18 | + repeat_inputs_to_array, |
| 19 | + RepeatParameterModel, |
| 20 | + ToolParameterT, |
| 21 | + validate_explicit_conditional_test_value, |
| 22 | + WorkflowStepLinkedToolState, |
| 23 | + WorkflowStepToolState, |
| 24 | +) |
| 25 | +from galaxy.tool_util_models import ParsedTool |
| 26 | +from ._types import ( |
| 27 | + Format2StepDict, |
| 28 | + Format2WorkflowDict, |
| 29 | + GetToolInfo, |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +def validate_workflow_format2(workflow_dict: Format2WorkflowDict, get_tool_info: GetToolInfo): |
| 34 | + steps = steps_as_list(workflow_dict) |
| 35 | + for step in steps: |
| 36 | + validate_step_format2(step, get_tool_info) |
| 37 | + |
| 38 | + |
| 39 | +def validate_step_format2(step_dict: Format2StepDict, get_tool_info: GetToolInfo): |
| 40 | + step_type = get_native_step_type(step_dict) |
| 41 | + if step_type != "tool": |
| 42 | + return |
| 43 | + tool_id = cast(str, step_dict.get("tool_id")) |
| 44 | + tool_version: Optional[str] = cast(Optional[str], step_dict.get("tool_version")) |
| 45 | + parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version) |
| 46 | + if parsed_tool is not None: |
| 47 | + validate_step_against(step_dict, parsed_tool) |
| 48 | + |
| 49 | + |
| 50 | +def validate_step_against(step_dict: Format2StepDict, parsed_tool: ParsedTool): |
| 51 | + source_tool_state_model = WorkflowStepToolState.parameter_model_for(parsed_tool.inputs) |
| 52 | + linked_tool_state_model = WorkflowStepLinkedToolState.parameter_model_for(parsed_tool.inputs) |
| 53 | + contains_format2_state = "state" in step_dict |
| 54 | + contains_native_state = "tool_state" in step_dict |
| 55 | + if contains_format2_state: |
| 56 | + assert source_tool_state_model |
| 57 | + source_tool_state_model.model_validate(step_dict["state"]) |
| 58 | + if not contains_native_state: |
| 59 | + if not contains_format2_state: |
| 60 | + step_dict["state"] = {} |
| 61 | + # setup links and then validate against model... |
| 62 | + linked_step = merge_inputs(step_dict, parsed_tool) |
| 63 | + linked_tool_state_model.model_validate(linked_step["state"]) |
| 64 | + |
| 65 | + |
| 66 | +def merge_inputs(step_dict: Format2StepDict, parsed_tool: ParsedTool) -> Format2StepDict: |
| 67 | + connect = pop_connect_from_step_dict(step_dict) |
| 68 | + step_dict = setup_connected_values(step_dict, connect) |
| 69 | + tool_inputs = parsed_tool.inputs |
| 70 | + |
| 71 | + state_at_level = step_dict["state"] |
| 72 | + |
| 73 | + for tool_input in tool_inputs: |
| 74 | + _merge_into_state(connect, tool_input, state_at_level) |
| 75 | + |
| 76 | + for key in connect: |
| 77 | + raise Exception(f"Failed to find parameter definition matching workflow linked key {key}") |
| 78 | + return step_dict |
| 79 | + |
| 80 | + |
| 81 | +def _merge_into_state( |
| 82 | + connect, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None, branch_connect=None |
| 83 | +): |
| 84 | + if branch_connect is None: |
| 85 | + branch_connect = connect |
| 86 | + |
| 87 | + name = tool_input.name |
| 88 | + parameter_type = tool_input.parameter_type |
| 89 | + state_path = flat_state_path(name, prefix) |
| 90 | + if parameter_type == "gx_conditional": |
| 91 | + conditional_state = state.get(name, {}) |
| 92 | + if name not in state: |
| 93 | + state[name] = conditional_state |
| 94 | + |
| 95 | + conditional = cast(ConditionalParameterModel, tool_input) |
| 96 | + when: ConditionalWhen = _select_which_when(conditional, conditional_state) |
| 97 | + test_parameter = conditional.test_parameter |
| 98 | + conditional_connect = keys_starting_with(branch_connect, state_path) |
| 99 | + _merge_into_state( |
| 100 | + connect, test_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect |
| 101 | + ) |
| 102 | + for when_parameter in when.parameters: |
| 103 | + _merge_into_state( |
| 104 | + connect, when_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect |
| 105 | + ) |
| 106 | + elif parameter_type == "gx_repeat": |
| 107 | + repeat_state_array = state.get(name, []) |
| 108 | + repeat = cast(RepeatParameterModel, tool_input) |
| 109 | + repeat_instance_connects = repeat_inputs_to_array(state_path, connect) |
| 110 | + for i, repeat_instance_connect in enumerate(repeat_instance_connects): |
| 111 | + while len(repeat_state_array) <= i: |
| 112 | + repeat_state_array.append({}) |
| 113 | + |
| 114 | + repeat_instance_prefix = f"{state_path}_{i}" |
| 115 | + for repeat_parameter in repeat.parameters: |
| 116 | + _merge_into_state( |
| 117 | + connect, |
| 118 | + repeat_parameter, |
| 119 | + repeat_state_array[i], |
| 120 | + prefix=repeat_instance_prefix, |
| 121 | + branch_connect=repeat_instance_connect, |
| 122 | + ) |
| 123 | + if repeat_state_array and name not in state: |
| 124 | + state[name] = repeat_state_array |
| 125 | + else: |
| 126 | + if state_path in branch_connect: |
| 127 | + state[name] = {"__class__": "ConnectedValue"} |
| 128 | + del connect[state_path] |
| 129 | + |
| 130 | + |
| 131 | +def _select_which_when(conditional: ConditionalParameterModel, state: dict) -> ConditionalWhen: |
| 132 | + test_parameter = conditional.test_parameter |
| 133 | + test_parameter_name = test_parameter.name |
| 134 | + explicit_test_value = state.get(test_parameter_name) |
| 135 | + test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value) |
| 136 | + for when in conditional.whens: |
| 137 | + if test_value is None and when.is_default_when: |
| 138 | + return when |
| 139 | + elif test_value == when.discriminator: |
| 140 | + return when |
| 141 | + else: |
| 142 | + raise Exception(f"Invalid conditional test value ({explicit_test_value}) for parameter ({test_parameter_name})") |
0 commit comments