diff --git a/test/nexus_test.go b/test/nexus_test.go index 3267f4c28..248ce3346 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -9,7 +9,6 @@ import ( "os" "slices" "strings" - "sync/atomic" "testing" "time" @@ -1061,24 +1060,50 @@ func TestAsyncOperationFromWorkflow(t *testing.T) { }) } -func runCancellationTypeTest(ctx context.Context, tc *testContext, cancellationType workflow.NexusOperationCancellationType, t *testing.T) (client.WorkflowRun, string, time.Time) { +// cancelTypeOp is a wrapper for a workflow run operation that delays responding to the cancel request so that time +// based assertions aren't flakey. +type cancelTypeOp struct { + nexus.UnimplementedOperation[string, string] + workflowRunOp nexus.Operation[string, string] + unblockCancelCh chan struct{} +} + +func (o *cancelTypeOp) Name() string { + return o.workflowRunOp.Name() +} + +func (o *cancelTypeOp) Start(ctx context.Context, input string, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[string], error) { + return o.workflowRunOp.Start(ctx, input, options) +} + +func (o *cancelTypeOp) Cancel(ctx context.Context, token string, options nexus.CancelOperationOptions) error { + if o.unblockCancelCh != nil { + // Should only be non-nil in the TRY_CANCEL case. + <-o.unblockCancelCh + } + return o.workflowRunOp.Cancel(ctx, token, options) +} + +func runCancellationTypeTest(ctx context.Context, tc *testContext, cancellationType workflow.NexusOperationCancellationType, unblockCancelCh chan struct{}, t *testing.T) (client.WorkflowRun, string, time.Time) { handlerWf := func(ctx workflow.Context, ownID string) (string, error) { err := workflow.Await(ctx, func() bool { return false }) // Delay completion after receiving cancellation so that assertions on end time aren't flakey. disconCtx, _ := workflow.NewDisconnectedContext(ctx) - _ = workflow.Sleep(disconCtx, time.Second) + workflow.GetSignalChannel(disconCtx, "unblock").Receive(disconCtx, nil) return "", err } - handlerID := atomic.Value{} - op := temporalnexus.NewWorkflowRunOperation( - "workflow-op", - handlerWf, - func(ctx context.Context, _ string, soo nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { - handlerID.Store(soo.RequestID) - return client.StartWorkflowOptions{ID: soo.RequestID}, nil - }, - ) + handlerID := uuid.NewString() + op := &cancelTypeOp{ + unblockCancelCh: unblockCancelCh, + workflowRunOp: temporalnexus.NewWorkflowRunOperation( + "workflow-op", + handlerWf, + func(ctx context.Context, _ string, soo nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + return client.StartWorkflowOptions{ID: handlerID}, nil + }, + ), + } var unblockedTime time.Time callerWf := func(ctx workflow.Context, cancellation workflow.NexusOperationCancellationType) error { @@ -1091,13 +1116,16 @@ func runCancellationTypeTest(ctx context.Context, tc *testContext, cancellationT return err } + disconCtx, _ := workflow.NewDisconnectedContext(ctx) // Use disconnected ctx so it is not auto canceled. if cancellation == workflow.NexusOperationCancellationTypeTryCancel || cancellation == workflow.NexusOperationCancellationTypeWaitRequested { - disconCtx, _ := workflow.NewDisconnectedContext(ctx) // Use disconnected ctx so it is not auto canceled. workflow.Go(disconCtx, func(ctx workflow.Context) { // Wake up the caller so it is not waiting for the operation to complete to get the next WFT. _ = workflow.Sleep(ctx, time.Millisecond) }) } + if cancellation == workflow.NexusOperationCancellationTypeWaitCompleted { + _ = workflow.SignalExternalWorkflow(disconCtx, handlerID, "", "unblock", nil).Get(disconCtx, nil) + } _ = fut.Get(ctx, nil) unblockedTime = workflow.Now(ctx).UTC() @@ -1119,11 +1147,7 @@ func runCancellationTypeTest(ctx context.Context, tc *testContext, cancellationT }, callerWf, cancellationType) require.NoError(t, err) require.Eventuallyf(t, func() bool { - id := handlerID.Load() - if id == nil { - return false - } - _, descErr := tc.client.DescribeWorkflow(ctx, id.(string), "") + _, descErr := tc.client.DescribeWorkflow(ctx, handlerID, "") return descErr == nil }, 2*time.Second, 20*time.Millisecond, "timed out waiting for handler wf to start") require.NoError(t, tc.client.CancelWorkflow(ctx, run.GetID(), run.GetRunID())) @@ -1135,7 +1159,15 @@ func runCancellationTypeTest(ctx context.Context, tc *testContext, cancellationT var canceledErr *temporal.CanceledError require.ErrorAs(t, err, &canceledErr) - return run, handlerID.Load().(string), unblockedTime + if unblockCancelCh != nil { + // Should only be non-nil in the TRY_CANCEL case. + close(unblockCancelCh) + } + if cancellationType != workflow.NexusOperationCancellationTypeWaitCompleted { + require.NoError(t, tc.client.SignalWorkflow(ctx, handlerID, "", "unblock", nil)) + } + + return run, handlerID, unblockedTime } func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { @@ -1148,7 +1180,7 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { defer cancel() tc := newTestContext(t, ctx) - callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeAbandon, t) + callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeAbandon, nil, t) require.NotZero(t, unblockedTime) // Verify that caller never sent a cancellation request. @@ -1172,7 +1204,8 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultNexusTestTimeout) defer cancel() tc := newTestContext(t, ctx) - callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeTryCancel, t) + unblockCancelCh := make(chan struct{}) + callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeTryCancel, unblockCancelCh, t) // Verify operation future was unblocked after cancel command was recorded. callerHist := tc.client.GetWorkflowHistory(ctx, callerRun.GetID(), callerRun.GetRunID(), false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) @@ -1185,6 +1218,8 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { foundRequestedEvent = true require.Greater(t, unblockedTime, event.EventTime.AsTime().UTC()) } + require.NotEqual(t, enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED, event.EventType) + require.NotEqual(t, enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED, event.EventType) callerCloseEvent = event } require.True(t, foundRequestedEvent) @@ -1204,7 +1239,7 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultNexusTestTimeout) defer cancel() tc := newTestContext(t, ctx) - callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeWaitRequested, t) + callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeWaitRequested, nil, t) // Verify operation future was unblocked after cancel request was delivered. callerHist := tc.client.GetWorkflowHistory(ctx, callerRun.GetID(), callerRun.GetRunID(), false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) @@ -1236,7 +1271,7 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultNexusTestTimeout) defer cancel() tc := newTestContext(t, ctx) - callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeWaitCompleted, t) + callerRun, handlerID, unblockedTime := runCancellationTypeTest(ctx, tc, workflow.NexusOperationCancellationTypeWaitCompleted, nil, t) // Verify operation future was unblocked after operation was cancelled. callerHist := tc.client.GetWorkflowHistory(ctx, callerRun.GetID(), callerRun.GetRunID(), false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) @@ -1247,7 +1282,7 @@ func TestAsyncOperationFromWorkflow_CancellationTypes(t *testing.T) { require.NoError(t, err) if event.EventType == enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED { foundCancelledEvent = true - require.Greater(t, unblockedTime, event.EventTime.AsTime().UTC()) + require.GreaterOrEqual(t, unblockedTime, event.EventTime.AsTime().UTC()) } callerCloseEvent = event }