Skip to content

Commit e7d1ed8

Browse files
committed
Workflow conversion and validation things.
1 parent f4126f3 commit e7d1ed8

21 files changed

+848
-5
lines changed

lib/galaxy/tool_util/parameters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
"RepeatParameterModel",
136136
"RawStateDict",
137137
"ValidationFunctionT",
138+
"is_optional",
138139
"validate_against_model",
139140
"validate_internal_job",
140141
"validate_internal_landing_request",
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Union
2+
3+
from packaging.version import Version
4+
5+
from .version import LegacyVersion
6+
7+
# Either kind of version object handled here: the package-local
# ``LegacyVersion`` or ``packaging.version.Version``.
AnyVersionT = Union[LegacyVersion, Version]
8+
9+
10+
__all__ = ["AnyVersionT"]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Abstractions for reasoning about tool state within Galaxy workflows.
2+
3+
Like everything else in galaxy-tool-util, this package should be independent of
4+
Galaxy's runtime. It is meant to provide utilities for reasoning about tool state
5+
(largely building on the abstractions in galaxy.tool_util.parameters) within the
6+
context of workflows.
7+
"""
8+
9+
from ._types import GetToolInfo
10+
from .convert import (
11+
ConversionValidationFailure,
12+
convert_state_to_format2,
13+
Format2State,
14+
)
15+
from .validation import validate_workflow
16+
17+
__all__ = (
18+
"ConversionValidationFailure",
19+
"convert_state_to_format2",
20+
"GetToolInfo",
21+
"Format2State",
22+
"validate_workflow",
23+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import (
2+
Any,
3+
Dict,
4+
Optional,
5+
Union,
6+
)
7+
8+
from typing_extensions import (
9+
Literal,
10+
Protocol,
11+
)
12+
13+
from galaxy.tool_util_models import ParsedTool
14+
15+
# Workflows, steps and tool states are all passed around as plain string-keyed
# dictionaries; these aliases only name the role each dictionary plays.
NativeWorkflowDict = Dict[str, Any]
Format2WorkflowDict = Dict[str, Any]
AnyWorkflowDict = Union[NativeWorkflowDict, Format2WorkflowDict]
# The two workflow serializations distinguished by this package.
WorkflowFormat = Literal["gxformat2", "native"]
NativeStepDict = Dict[str, Any]
Format2StepDict = Dict[str, Any]
NativeToolStateDict = Dict[str, Any]
Format2StateDict = Dict[str, Any]
23+
24+
25+
class GetToolInfo(Protocol):
    """Structural interface for objects that can look up tool information for workflow steps."""

    def get_tool_info(self, tool_id: str, tool_version: Optional[str]) -> ParsedTool:
        """Return the parsed tool identified by ``tool_id`` and ``tool_version``.

        :param tool_id: identifier of the tool referenced by a workflow step.
        :param tool_version: version recorded on the step, or ``None`` when the step has none.
        """
        ...
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from typing import (
2+
Dict,
3+
List,
4+
Optional,
5+
)
6+
7+
from pydantic import (
8+
BaseModel,
9+
Field,
10+
)
11+
12+
from galaxy.tool_util.parameters import ToolParameterT
13+
from galaxy.tool_util_models import ParsedTool
14+
from ._types import (
15+
Format2StateDict,
16+
GetToolInfo,
17+
NativeStepDict,
18+
)
19+
from .validation_format2 import validate_step_against
20+
from .validation_native import (
21+
get_parsed_tool_for_native_step,
22+
native_tool_state,
23+
validate_native_step_against,
24+
)
25+
26+
# Mapping of a parameter's flat state path to the string recorded for its connection.
Format2InputsDictT = Dict[str, str]
27+
28+
29+
class Format2State(BaseModel):
    """Result of converting a native step: the gxformat2 state plus its connected inputs."""

    # Converted tool state, keyed by parameter name.
    state: Format2StateDict
    # Connected inputs; aliased to "in" because ``in`` is a reserved word in Python.
    inputs: Format2InputsDictT = Field(alias="in")
32+
33+
34+
class ConversionValidationFailure(Exception):
    """Signals that a native tool state could not be converted to a validated gxformat2 state."""
36+
37+
38+
def convert_state_to_format2(native_step_dict: NativeStepDict, get_tool_info: GetToolInfo) -> Format2State:
    """Look up the tool for a native step and convert that step's state to gxformat2 form.

    The tool is resolved with ``get_parsed_tool_for_native_step`` and the actual
    conversion is delegated to ``convert_state_to_format2_using``.
    """
    tool_for_step = get_parsed_tool_for_native_step(native_step_dict, get_tool_info)
    converted = convert_state_to_format2_using(native_step_dict, tool_for_step)
    return converted
41+
42+
43+
def convert_state_to_format2_using(native_step_dict: NativeStepDict, parsed_tool: Optional[ParsedTool]) -> Format2State:
    """Create a "clean" gxformat2 workflow tool state from a native workflow step.

    gxformat2 does not know about tool specifications so it cannot reason about the native
    tool state attribute and just copies it as is. This native state can be pretty ugly. The
    purpose of this function is to build a cleaned up state, replacing the native tool_state
    gxformat2 copied, that is more readable and has stronger typing by using the tool's inputs
    to guide the conversion (the parsed_tool parameter).

    This method validates both the native tool state and the resulting gxformat2 tool state
    so that we can be more confident the conversion doesn't corrupt the workflow. If no meta
    model to validate against is supplied or if either validation fails this method throws
    ConversionValidationFailure to signal the caller to just use the native tool state as is
    instead of trying to convert it to a cleaner gxformat2 tool state - under the assumption
    it is better to have an "ugly" workflow than a corrupted one during conversion.

    :param native_step_dict: the native workflow step whose tool state is converted.
    :param parsed_tool: the tool definition guiding the conversion, or ``None`` if unresolved.
    :returns: the converted state together with its connected inputs.
    :raises ConversionValidationFailure: if ``parsed_tool`` is ``None`` or either the native
        step or the converted step fails validation; the original error is chained as the cause.
    """
    if parsed_tool is None:
        raise ConversionValidationFailure("Could not resolve tool inputs")
    try:
        validate_native_step_against(native_step_dict, parsed_tool)
    except Exception as e:
        # Chain the underlying error so the reason validation failed is not lost.
        raise ConversionValidationFailure(
            "Failed to validate native step - not going to convert a tool state that isn't understood"
        ) from e
    result = _convert_valid_state_to_format2(native_step_dict, parsed_tool)
    try:
        # NOTE(review): dict() is called without by_alias, so the connections are emitted
        # under "inputs" rather than "in" - confirm this is what validate_step_against expects.
        validate_step_against(result.dict(), parsed_tool)
    except Exception as e:
        raise ConversionValidationFailure(
            "Failed to validate resulting cleaned step - not going to convert to an unvalidated tool state"
        ) from e
    return result
76+
77+
78+
def _convert_valid_state_to_format2(native_step_dict: NativeStepDict, parsed_tool: ParsedTool) -> Format2State:
    """Build a ``Format2State`` from a native step using the tool's declared inputs.

    Two empty dictionaries are created and filled in place by
    ``_convert_state_level``, which walks the tool's top-level inputs.
    """
    converted_state: Format2StateDict = {}
    converted_in: Format2InputsDictT = {}

    native_root_state = native_tool_state(native_step_dict)
    _convert_state_level(native_step_dict, parsed_tool.inputs, native_root_state, converted_state, converted_in)

    # "in" is a reserved word, so the model is populated through a keyword dictionary.
    model_kwargs = {"state": converted_state, "in": converted_in}
    return Format2State(**model_kwargs)
91+
92+
93+
def _convert_state_level(
    step: NativeStepDict,
    tool_inputs: List[ToolParameterT],
    native_state: dict,
    format2_state_at_level: dict,
    format2_in: Format2InputsDictT,
    prefix: Optional[str] = None,
) -> None:
    """Convert every parameter found at one nesting level of the tool state.

    Each entry of ``tool_inputs`` is handed to ``_convert_state_at_level`` along
    with the shared output dictionaries. A missing or empty ``prefix`` is passed
    on as the empty string.
    """
    level_prefix = prefix if prefix else ""
    for parameter in tool_inputs:
        _convert_state_at_level(step, parameter, native_state, format2_state_at_level, format2_in, level_prefix)
105+
106+
107+
def _convert_state_at_level(
108+
step: NativeStepDict,
109+
tool_input: ToolParameterT,
110+
native_state_at_level: dict,
111+
format2_state_at_level: dict,
112+
format2_in: Format2InputsDictT,
113+
prefix: str,
114+
) -> None:
115+
parameter_type = tool_input.parameter_type
116+
parameter_name = tool_input.name
117+
value = native_state_at_level.get(parameter_name, None)
118+
state_path = parameter_name if prefix is None else f"{prefix}|{parameter_name}"
119+
if parameter_type == "gx_integer":
120+
# check for runtime input
121+
try:
122+
format2_value = int(value) # type: ignore[arg-type]
123+
except ValueError:
124+
raise Exception(f"Failed to convert integer value {value} for parameter {parameter_name}")
125+
format2_state_at_level[parameter_name] = format2_value
126+
elif parameter_type == "gx_data":
127+
input_connections = step.get("input_connections", {})
128+
print(state_path)
129+
print(input_connections)
130+
if state_path in input_connections:
131+
format2_in[state_path] = "placeholder"
132+
else:
133+
pass
134+
# raise NotImplementedError(f"Unhandled parameter type {parameter_type}")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from ._types import (
2+
AnyWorkflowDict,
3+
GetToolInfo,
4+
WorkflowFormat,
5+
)
6+
from .validation_format2 import validate_workflow_format2
7+
from .validation_native import validate_workflow_native
8+
9+
10+
def validate_workflow(workflow_dict: AnyWorkflowDict, get_tool_info: GetToolInfo):
    """Validate a workflow dictionary with the validator matching its detected format.

    The format is determined by ``_format``; gxformat2 workflows go to
    ``validate_workflow_format2`` and everything else to ``validate_workflow_native``.
    """
    if _format(workflow_dict) != "gxformat2":
        validate_workflow_native(workflow_dict, get_tool_info)
        return
    validate_workflow_format2(workflow_dict, get_tool_info)
15+
16+
17+
def _format(workflow_dict: AnyWorkflowDict) -> WorkflowFormat:
18+
if workflow_dict.get("a_galaxy_workflow") == "true":
19+
return "native"
20+
else:
21+
return "gxformat2"
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from typing import (
2+
cast,
3+
Optional,
4+
)
5+
6+
from gxformat2.model import (
7+
get_native_step_type,
8+
pop_connect_from_step_dict,
9+
setup_connected_values,
10+
steps_as_list,
11+
)
12+
13+
from galaxy.tool_util.parameters import (
14+
ConditionalParameterModel,
15+
ConditionalWhen,
16+
flat_state_path,
17+
keys_starting_with,
18+
repeat_inputs_to_array,
19+
RepeatParameterModel,
20+
ToolParameterT,
21+
validate_explicit_conditional_test_value,
22+
WorkflowStepLinkedToolState,
23+
WorkflowStepToolState,
24+
)
25+
from galaxy.tool_util_models import ParsedTool
26+
from ._types import (
27+
Format2StepDict,
28+
Format2WorkflowDict,
29+
GetToolInfo,
30+
)
31+
32+
33+
def validate_workflow_format2(workflow_dict: Format2WorkflowDict, get_tool_info: GetToolInfo):
    """Validate each step of a gxformat2 workflow, in the order ``steps_as_list`` returns them."""
    for step_dict in steps_as_list(workflow_dict):
        validate_step_format2(step_dict, get_tool_info)
37+
38+
39+
def validate_step_format2(step_dict: Format2StepDict, get_tool_info: GetToolInfo):
    """Validate one gxformat2 step against its tool definition.

    Steps whose native step type is not ``"tool"`` are ignored, as are tool
    steps for which ``get_tool_info`` returns ``None``.
    """
    if get_native_step_type(step_dict) != "tool":
        return

    tool_id = cast(str, step_dict.get("tool_id"))
    tool_version: Optional[str] = cast(Optional[str], step_dict.get("tool_version"))
    parsed_tool = get_tool_info.get_tool_info(tool_id, tool_version)
    if parsed_tool is None:
        # Nothing to validate against.
        return
    validate_step_against(step_dict, parsed_tool)
48+
49+
50+
def validate_step_against(step_dict: Format2StepDict, parsed_tool: ParsedTool):
    """Validate a step's state against models built from the tool's inputs.

    When the step has a ``state`` entry it is validated as-is against the
    ``WorkflowStepToolState`` model. Unless the step carries a native
    ``tool_state`` entry, its connections are then merged into the state and the
    result is validated against the ``WorkflowStepLinkedToolState`` model.

    Note that ``step_dict`` is modified: a missing ``state`` entry is set to an
    empty dictionary before connections are merged.
    """
    unlinked_model = WorkflowStepToolState.parameter_model_for(parsed_tool.inputs)
    linked_model = WorkflowStepLinkedToolState.parameter_model_for(parsed_tool.inputs)
    has_format2_state = "state" in step_dict
    has_native_state = "tool_state" in step_dict

    if has_format2_state:
        assert unlinked_model
        unlinked_model.model_validate(step_dict["state"])

    if has_native_state:
        return

    if not has_format2_state:
        step_dict["state"] = {}
    # Set up links and then validate against the linked model.
    step_with_links = merge_inputs(step_dict, parsed_tool)
    linked_model.model_validate(step_with_links["state"])
64+
65+
66+
def merge_inputs(step_dict: Format2StepDict, parsed_tool: ParsedTool) -> Format2StepDict:
    """Fold a step's connections into its ``state`` and return the resulting step.

    The connections are taken from the step with ``pop_connect_from_step_dict``
    and each tool input is given the chance to claim the ones addressed to it.
    Any connection left unclaimed afterwards is an error.

    :raises Exception: if a connection key matches no parameter definition.
    """
    connect = pop_connect_from_step_dict(step_dict)
    step_dict = setup_connected_values(step_dict, connect)
    declared_inputs = parsed_tool.inputs
    root_state = step_dict["state"]

    for declared_input in declared_inputs:
        _merge_into_state(connect, declared_input, root_state)

    # Matched keys were removed from connect while merging; anything still here
    # has no corresponding parameter, so fail on the first leftover.
    for key in connect:
        raise Exception(f"Failed to find parameter definition matching workflow linked key {key}")

    return step_dict
79+
80+
81+
def _merge_into_state(
    connect, tool_input: ToolParameterT, state: dict, prefix: Optional[str] = None, branch_connect=None
):
    """Mark connected parameters in ``state`` and remove their keys from ``connect``.

    Conditionals and repeats are descended into recursively; any other parameter
    whose flat state path appears in ``branch_connect`` gets a ``ConnectedValue``
    marker in ``state`` and its key is deleted from ``connect``.

    :param connect: all remaining connections for the step; matched keys are deleted from it.
    :param tool_input: the parameter definition being processed.
    :param state: the state dictionary at the current nesting level, updated in place.
    :param prefix: state path of the enclosing level, if any.
    :param branch_connect: connections relevant to the current branch; defaults to ``connect``.
    """
    if branch_connect is None:
        branch_connect = connect

    name = tool_input.name
    parameter_type = tool_input.parameter_type
    state_path = flat_state_path(name, prefix)

    if parameter_type == "gx_conditional":
        # Reuse the existing nested state, creating (and storing) one when absent.
        conditional_state = state.setdefault(name, {})

        conditional = cast(ConditionalParameterModel, tool_input)
        when: ConditionalWhen = _select_which_when(conditional, conditional_state)
        test_parameter = conditional.test_parameter
        conditional_connect = keys_starting_with(branch_connect, state_path)
        _merge_into_state(
            connect, test_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect
        )
        for when_parameter in when.parameters:
            _merge_into_state(
                connect, when_parameter, conditional_state, prefix=state_path, branch_connect=conditional_connect
            )
        return

    if parameter_type == "gx_repeat":
        repeat_instances = state.get(name, [])
        repeat = cast(RepeatParameterModel, tool_input)
        per_instance_connects = repeat_inputs_to_array(state_path, connect)
        for index, instance_connect in enumerate(per_instance_connects):
            # Grow the instance list so there is a state dictionary for this index.
            while len(repeat_instances) <= index:
                repeat_instances.append({})

            instance_prefix = f"{state_path}_{index}"
            for repeat_parameter in repeat.parameters:
                _merge_into_state(
                    connect,
                    repeat_parameter,
                    repeat_instances[index],
                    prefix=instance_prefix,
                    branch_connect=instance_connect,
                )
        # Only store a newly built list when it actually gained instances.
        if repeat_instances and name not in state:
            state[name] = repeat_instances
        return

    if state_path in branch_connect:
        state[name] = {"__class__": "ConnectedValue"}
        del connect[state_path]
129+
130+
131+
def _select_which_when(conditional: ConditionalParameterModel, state: dict) -> ConditionalWhen:
    """Pick the ``when`` branch of a conditional that matches the state's test value.

    The test value is read from ``state`` under the test parameter's name and
    passed through ``validate_explicit_conditional_test_value``. A ``None`` result
    selects the first branch flagged as the default; otherwise the first branch
    whose discriminator equals the value is chosen.

    :raises Exception: if no branch matches.
    """
    test_parameter_name = conditional.test_parameter.name
    explicit_test_value = state.get(test_parameter_name)
    test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value)
    for candidate in conditional.whens:
        is_default_match = test_value is None and candidate.is_default_when
        if is_default_match or test_value == candidate.discriminator:
            return candidate
    raise Exception(f"Invalid conditional test value ({explicit_test_value}) for parameter ({test_parameter_name})")

0 commit comments

Comments
 (0)