Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 77 additions & 32 deletions jumpstarter/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import anyio
import transitions
from anyio.abc import Event
from anyio.abc import Event as EventType
from transitions import EventData
from transitions.core import _LOGGER, MachineError, listify
from transitions.extensions import GraphMachine
Expand Down Expand Up @@ -34,6 +34,34 @@
NestedState.separator = "↦"


class Transition(str, Enum):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦‍♂️😆

def _generate_next_value_(name: str, *args) -> str:
name = name.replace("__", NestedState.separator)
return name.lower()

initialize = auto()
pause = auto()
recover = auto()
report_error = auto()
report_problem = auto()
report_warning = auto()
restart = auto()
restarting__starting = auto()
restarting__stopping = auto()
resume = auto()
start = auto()
stop = auto()


class Event(str, Enum):
def _generate_next_value_(name: str, *args) -> str:
name = name.replace("__", NestedState.separator)
return f"{name.lower()}_event"

bootup = auto()
shutdown = auto()


# region Enums


Expand Down Expand Up @@ -178,35 +206,41 @@ def __init__(
)

self.add_transition(
"restart",
Transition.restart,
restart_state.ignore,
restart_state.restarting,
after="restart",
after=Transition.restart,
conditions=self._can_restart,
)
self.add_transition(
"restart",
Transition.restart,
restart_state.restarting.value.stopping,
restart_state.restarting.value.starting,
after="restart",
after=Transition.restart,
)
self.add_transition(
"restart", restart_state.restarting.value.starting, restart_state.restarted
Transition.restart,
restart_state.restarting.value.starting,
restart_state.restarted,
)

self.add_transition(
"restart",
Transition.restart,
restart_state.restarted,
restart_state.restarting,
after="restart",
after=Transition.restart,
conditions=self._can_restart,
)

self.on_enter("restarting↦stopping", self._stop_and_wait_for_completion)
self.on_enter("restarting↦starting", self._start_and_wait_for_completion)
self.on_enter(
Transition.restarting__stopping.value, self._stop_and_wait_for_completion
)
self.on_enter(
Transition.restarting__starting.value, self._start_and_wait_for_completion
)

async def _stop_and_wait_for_completion(self, event_data: EventData) -> None:
shutdown_event: Event = anyio.create_event()
shutdown_event: EventType = anyio.create_event()

async with anyio.create_task_group() as task_group:
await task_group.spawn(
Expand All @@ -215,7 +249,7 @@ async def _stop_and_wait_for_completion(self, event_data: EventData) -> None:
await task_group.spawn(shutdown_event.wait)

async def _start_and_wait_for_completion(self, event_data: EventData) -> None:
bootup_event: Event = anyio.create_event()
bootup_event: EventType = anyio.create_event()

async with anyio.create_task_group() as task_group:
await task_group.spawn(
Expand Down Expand Up @@ -296,42 +330,50 @@ def register_parallel_state_machine(self, machine: BaseStateMachine) -> None:
# region Protected API

def _create_crashed_transitions(self, actor_state):
self.add_transition("report_error", "*", actor_state.crashed)
self.add_transition(Transition.report_error, "*", actor_state.crashed)
self.add_transition(
"stop", actor_state.crashed, actor_state.stopping, after="stop"
Transition.stop,
actor_state.crashed,
actor_state.stopping,
after=Transition.stop,
)
self.add_transition(
"start", actor_state.crashed, actor_state.starting, after="start"
Transition.start,
actor_state.crashed,
actor_state.starting,
after=Transition.start,
)

def _create_started_substates_transitions(self, actor_state):
self.add_transition(
"pause", actor_state.started.value.running, actor_state.started.value.paused
Transition.pause,
actor_state.started.value.running,
actor_state.started.value.paused,
)
self.add_transition(
"resume",
Transition.resume,
actor_state.started.value.paused,
actor_state.started.value.running.value.healthy,
)

self.add_transition(
"recover",
Transition.recover,
[
actor_state.started.value.running.value.degraded,
actor_state.started.value.running.value.unhealthy,
],
actor_state.started.value.running.value.healthy,
)
self.add_transition(
"report_warning",
Transition.report_warning,
[
actor_state.started.value.running.value.healthy,
actor_state.started.value.running.value.unhealthy,
],
actor_state.started.value.running.value.degraded,
)
self.add_transition(
"report_problem",
Transition.report_problem,
[
actor_state.started.value.running.value.degraded,
actor_state.started.value.running.value.healthy,
Expand All @@ -341,7 +383,10 @@ def _create_started_substates_transitions(self, actor_state):

def _create_restart_transitions(self, actor_state):
self.add_transition(
"start", actor_state.stopped, actor_state.starting, after="start"
Transition.start,
actor_state.stopped,
actor_state.starting,
after=Transition.start,
)

def _create_shutdown_transitions(self, actor_state):
Expand All @@ -353,27 +398,27 @@ def _create_shutdown_transitions(self, actor_state):
actor_state.stopping.value.resources_released,
actor_state.stopping.value.dependencies_stopped,
],
trigger="stop",
trigger=Transition.stop,
loop=False,
after="stop",
after=Transition.stop,
)
self.add_transition(
"stop",
Transition.stop,
actor_state.stopping.value.dependencies_stopped,
actor_state.stopped,
after=partial(_maybe_set_event, event_name="shutdown_event"),
after=partial(_maybe_set_event, event_name=Event.shutdown),
)

transition = self.get_transitions(
"stop",
Transition.stop,
actor_state.stopping.value.tasks_stopped,
actor_state.stopping.value.resources_released,
)[0]
transition.before.append(_release_resources)

def _create_bootup_transitions(self, actor_state):
self.add_transition(
"initialize", actor_state.initializing, actor_state.initialized
Transition.initialize, actor_state.initializing, actor_state.initialized
)

self.add_ordered_transitions(
Expand All @@ -385,15 +430,15 @@ def _create_bootup_transitions(self, actor_state):
actor_state.starting.value.resources_acquired,
actor_state.starting.value.tasks_started,
],
trigger="start",
trigger=Transition.start,
loop=False,
after="start",
after=Transition.start,
)
self.add_transition(
"start",
Transition.start,
actor_state.starting.value.tasks_started,
actor_state.started,
after=partial(_maybe_set_event, event_name="bootup_event"),
after=partial(_maybe_set_event, event_name=Event.bootup),
)

# endregion
Expand All @@ -409,7 +454,7 @@ async def _release_resources(event_data: transitions.EventData) -> None:
async def _maybe_set_event(event_data: EventData, event_name: str) -> None:
kwargs = _merge_event_data_kwargs(event_data)
try:
event: Event = kwargs[event_name]
event: EventType = kwargs[event_name]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the anyio.Event at import to not conflict with the Event enum. It seemed to me that the having a variable Event.bootup and Event.shutdown was more useful than the type annotation. By changing the name of the annotation at import, mypy will still work, and IDEs will still prompt you to use the correct value. It just avoids the name collision.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from anyio.abc import Event as EventType

await event.set()
except KeyError:
pass
Expand Down