Skip to content

Commit 9f3b48f

Browse files
authored
Mtoledo/add interruptible parameter (#86)
* add interruptible field [wip] * interruptible to node metadata * interruptible default to false and upd tests * add interruptible to workflow * upd idl * upd tests * add interruptible to tasks * upd python and dynamic task * upd taskmetadata * upd workflow closure test * upd version * upd from_idl and add tests * upd parameterizers * dummy commit to retrigger build * udp workflow metadata * comment out tests * merge conflict
1 parent fd17913 commit 9f3b48f

File tree

19 files changed

+219
-95
lines changed

19 files changed

+219
-95
lines changed

flytekit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import absolute_import
22
import flytekit.plugins
33

4-
__version__ = '0.6.0b1'
4+
__version__ = '0.6.0b2'

flytekit/common/tasks/hive_task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
task_type,
3636
discovery_version,
3737
retries,
38+
interruptible,
3839
deprecated,
3940
storage_request,
4041
cpu_request,
@@ -71,7 +72,7 @@ def __init__(
7172
:param dict[Text, Text] environment:
7273
"""
7374
self._task_function = task_function
74-
super(SdkHiveTask, self).__init__(task_function, task_type, discovery_version, retries, deprecated,
75+
super(SdkHiveTask, self).__init__(task_function, task_type, discovery_version, retries, interruptible, deprecated,
7576
storage_request, cpu_request, gpu_request, memory_request, storage_limit,
7677
cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, {})
7778
self._validate_task_parameters(cluster_label, tags)

flytekit/common/tasks/sdk_dynamic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
task_type,
5050
discovery_version,
5151
retries,
52+
interruptible,
5253
deprecated,
5354
storage_request,
5455
cpu_request,
@@ -70,6 +71,7 @@ def __init__(
7071
:param Text task_type: string describing the task type
7172
:param Text discovery_version: string describing the version for task discovery purposes
7273
:param int retries: Number of retries to attempt
74+
:param bool interruptible: Whether or not task is interruptible
7375
:param Text deprecated:
7476
:param Text storage_request:
7577
:param Text cpu_request:
@@ -87,7 +89,7 @@ def __init__(
8789
:param dict[Text, T] custom:
8890
"""
8991
super(SdkDynamicTask, self).__init__(
90-
task_function, task_type, discovery_version, retries, deprecated,
92+
task_function, task_type, discovery_version, retries, interruptible, deprecated,
9193
storage_request, cpu_request, gpu_request, memory_request, storage_limit,
9294
cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, custom)
9395

flytekit/common/tasks/sdk_runnable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def __init__(
164164
task_type,
165165
discovery_version,
166166
retries,
167+
interruptible,
167168
deprecated,
168169
storage_request,
169170
cpu_request,
@@ -183,6 +184,7 @@ def __init__(
183184
:param Text task_type: string describing the task type
184185
:param Text discovery_version: string describing the version for task discovery purposes
185186
:param int retries: Number of retries to attempt
187+
:param bool interruptible: Specify whether task is interruptible
186188
:param Text deprecated:
187189
:param Text storage_request:
188190
:param Text cpu_request:
@@ -210,6 +212,7 @@ def __init__(
210212
),
211213
timeout,
212214
_literal_models.RetryStrategy(retries),
215+
interruptible,
213216
discovery_version,
214217
deprecated
215218
),

flytekit/common/tasks/sidecar_task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self,
2626
task_type,
2727
discovery_version,
2828
retries,
29+
interruptible,
2930
deprecated,
3031
storage_request,
3132
cpu_request,
@@ -56,6 +57,7 @@ def __init__(self,
5657
task_type,
5758
discovery_version,
5859
retries,
60+
interruptible,
5961
deprecated,
6062
storage_request,
6163
cpu_request,

flytekit/common/tasks/spark_task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
task_type,
5858
discovery_version,
5959
retries,
60+
interruptible,
6061
deprecated,
6162
discoverable,
6263
timeout,
@@ -69,6 +70,7 @@ def __init__(
6970
:param Text task_type: string describing the task type
7071
:param Text discovery_version: string describing the version for task discovery purposes
7172
:param int retries: Number of retries to attempt
73+
:param bool interruptible: Whether or not task is interruptible
7274
:param Text deprecated:
7375
:param bool discoverable:
7476
:param datetime.timedelta timeout:
@@ -92,6 +94,7 @@ def __init__(
9294
task_type,
9395
discovery_version,
9496
retries,
97+
interruptible,
9598
deprecated,
9699
"",
97100
"",

flytekit/common/tasks/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __call__(self, *args, **input_map):
123123
# TODO: Remove DEADBEEF
124124
return _nodes.SdkNode(
125125
id=None,
126-
metadata=_workflow_model.NodeMetadata("DEADBEEF", self.metadata.timeout, self.metadata.retries),
126+
metadata=_workflow_model.NodeMetadata("DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible),
127127
bindings=sorted(bindings, key=lambda b: b.var),
128128
upstream_nodes=upstream_nodes,
129129
sdk_task=self

flytekit/common/workflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(self, inputs, outputs, nodes, id=None, metadata=None, interface=Non
131131
super(SdkWorkflow, self).__init__(
132132
id=id,
133133
metadata=metadata,
134+
metadata_defaults=_workflow_models.WorkflowMetadataDefaults(),
134135
interface=interface,
135136
nodes=nodes,
136137
outputs=output_bindings,
@@ -255,6 +256,7 @@ def promote_from_model(cls, base_model, sub_workflows=None, tasks=None):
255256
inputs=None, outputs=None, nodes=list(node_map.values()),
256257
id=_identifier.Identifier.promote_from_model(base_model.id),
257258
metadata=base_model.metadata,
259+
metadata_defaults=base_model.metadata_defaults,
258260
interface=_interface.TypedInterface.promote_from_model(base_model.interface),
259261
output_bindings=base_model.outputs,
260262
)

flytekit/contrib/sensors/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def _execute_user_code(self, context, inputs):
2222
def sensor_task(
2323
_task_function=None,
2424
retries=0,
25+
interruptible=None,
2526
deprecated='',
2627
storage_request=None,
2728
cpu_request=None,
@@ -57,6 +58,7 @@ def my_task(wf_params):
5758
.. note::
5859
If retries > 0, the task must be able to recover from any remote state created within the user code. It is
5960
strongly recommended that tasks are written to be idempotent.
61+
:param bool interruptible: Specify whether task is interruptible
6062
:param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
6163
will be logged as a warning so it should contain information regarding how to update to a newer task.
6264
:param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
@@ -99,6 +101,7 @@ def wrapper(fn):
99101
task_function=fn,
100102
task_type=_common_constants.SdkTaskType.SENSOR_TASK,
101103
retries=retries,
104+
interruptible=interruptible,
102105
deprecated=deprecated,
103106
storage_request=storage_request,
104107
cpu_request=cpu_request,

flytekit/models/core/workflow.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def from_flyte_idl(cls, pb2_objct):
148148

149149
class NodeMetadata(_common.FlyteIdlEntity):
150150

151-
def __init__(self, name, timeout, retries):
151+
def __init__(self, name, timeout, retries, interruptible=False):
152152
"""
153153
Defines extra information about the Node.
154154
@@ -159,6 +159,7 @@ def __init__(self, name, timeout, retries):
159159
self._name = name
160160
self._timeout = timeout
161161
self._retries = retries
162+
self._interruptible = interruptible
162163

163164
@property
164165
def name(self):
@@ -181,11 +182,18 @@ def retries(self):
181182
"""
182183
return self._retries
183184

185+
@property
186+
def interruptible(self):
187+
"""
188+
:rtype: flytekit.models.
189+
"""
190+
return self._interruptible
191+
184192
def to_flyte_idl(self):
185193
"""
186194
:rtype: flyteidl.core.workflow_pb2.NodeMetadata
187195
"""
188-
node_metadata = _core_workflow.NodeMetadata(name=self.name, retries=self.retries.to_flyte_idl())
196+
node_metadata = _core_workflow.NodeMetadata(name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible)
189197
node_metadata.timeout.FromTimedelta(self.timeout)
190198
return node_metadata
191199

@@ -458,10 +466,34 @@ def from_flyte_idl(cls, pb2_object):
458466
"""
459467
return cls()
460468

469+
class WorkflowMetadataDefaults(_common.FlyteIdlEntity):
470+
471+
def __init__(self, interruptible=None):
472+
"""
473+
Metadata Defaults for the workflow.
474+
"""
475+
self.interruptible_ = interruptible
476+
477+
def to_flyte_idl(self):
478+
"""
479+
:rtype: flyteidl.core.workflow_pb2.WorkflowMetadataDefaults
480+
"""
481+
return _core_workflow.WorkflowMetadataDefaults(
482+
interruptible=self.interruptible_
483+
)
484+
485+
@classmethod
486+
def from_flyte_idl(cls, pb2_object):
487+
"""
488+
:param flyteidl.core.workflow_pb2.WorkflowMetadataDefaults pb2_object:
489+
:rtype: WorkflowMetadata
490+
"""
491+
return cls(interruptible=pb2_object.interruptible)
492+
461493

462494
class WorkflowTemplate(_common.FlyteIdlEntity):
463495

464-
def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
496+
def __init__(self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None):
465497
"""
466498
A workflow template encapsulates all the task, branch, and subworkflow nodes to run a statically analyzable,
467499
directed acyclic graph. It contains also metadata that tells the system how to execute the workflow (i.e.
@@ -470,6 +502,7 @@ def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
470502
:param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is
471503
globally unique across Flyte.
472504
:param WorkflowMetadata metadata: This contains information on how to run the workflow.
505+
:param WorkflowMetadataDefaults metadata_defaults: This contains the default information on how to run the workflow.
473506
:param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the
474507
Workflow (inputs, outputs). This can include some optional parameters.
475508
:param list[Node] nodes: A list of nodes. In addition, "globals" is a special reserved node id that
@@ -485,6 +518,7 @@ def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
485518
"""
486519
self._id = id
487520
self._metadata = metadata
521+
self._metadata_defaults = metadata_defaults
488522
self._interface = interface
489523
self._nodes = nodes
490524
self._outputs = outputs
@@ -506,6 +540,14 @@ def metadata(self):
506540
"""
507541
return self._metadata
508542

543+
@property
544+
def metadata_defaults(self):
545+
"""
546+
This contains information on how to run the workflow.
547+
:rtype: WorkflowMetadataDefaults
548+
"""
549+
return self._metadata_defaults
550+
509551
@property
510552
def interface(self):
511553
"""
@@ -552,6 +594,7 @@ def to_flyte_idl(self):
552594
return _core_workflow.WorkflowTemplate(
553595
id=self.id.to_flyte_idl(),
554596
metadata=self.metadata.to_flyte_idl(),
597+
metadata_defaults=self.metadata_defaults.to_flyte_idl(),
555598
interface=self.interface.to_flyte_idl(),
556599
nodes=[n.to_flyte_idl() for n in self.nodes],
557600
outputs=[o.to_flyte_idl() for o in self.outputs],
@@ -567,6 +610,7 @@ def from_flyte_idl(cls, pb2_object):
567610
return cls(
568611
id=_identifier.Identifier.from_flyte_idl(pb2_object.id),
569612
metadata=WorkflowMetadata.from_flyte_idl(pb2_object.metadata),
613+
metadata_defaults=WorkflowMetadataDefaults.from_flyte_idl(pb2_object.metadata_defaults),
570614
interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface),
571615
nodes=[Node.from_flyte_idl(n) for n in pb2_object.nodes],
572616
outputs=[_Binding.from_flyte_idl(b) for b in pb2_object.outputs],

0 commit comments

Comments
 (0)