diff --git a/examples/common/common.go b/examples/common/common.go index 3d6803b..796b4fb 100644 --- a/examples/common/common.go +++ b/examples/common/common.go @@ -15,9 +15,11 @@ package common import ( "errors" "fmt" - "github.com/metaform/dataplane-sdk-go/pkg/dsdk" "net/http" "strings" + + "github.com/go-chi/chi/v5" + "github.com/metaform/dataplane-sdk-go/pkg/dsdk" ) const ( @@ -36,13 +38,27 @@ type TokenResponse struct { // NewSignalingServer creates and returns a new HTTP server configured with dataplane signaling endpoints. func NewSignalingServer(sdkApi *dsdk.DataPlaneApi, port int) *http.Server { - mux := http.NewServeMux() - mux.HandleFunc("/start", sdkApi.Start) - mux.HandleFunc("/prepare", sdkApi.Prepare) - mux.HandleFunc("/terminate/", sdkApi.Terminate) - mux.HandleFunc("/suspend/", sdkApi.Suspend) + r := chi.NewRouter() + r.Post("/dataflows/start", sdkApi.Start) + r.Post("/dataflows/{id}/start", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.StartById(writer, request, id) + }) + r.Post("/dataflows/prepare", sdkApi.Prepare) + r.Post("/dataflows/{id}/terminate", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Terminate(id, writer, request) + }) + r.Post("/dataflows/{id}/suspend", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Suspend(id, writer, request) + }) + r.Get("/dataflows/{id}/status", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Status(id, writer, request) + }) - return &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} + return &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: r} } // NewDataServer creates and initializes a new HTTP server with a specified port and request handler. diff --git a/examples/controlplane/controlplane.go b/examples/controlplane/controlplane.go index 3769a4f..a284584 100644 --- a/examples/controlplane/controlplane.go +++ b/examples/controlplane/controlplane.go @@ -27,9 +27,9 @@ import ( ) const ( - startUrl = "http://localhost:%d/start" - terminateUrl = "http://localhost:%d/terminate/%s" - consumerPrepareURL = "http://localhost:%d/prepare" + startUrl = "http://localhost:%d/dataflows/start" + terminateUrl = "http://localhost:%d/dataflows/%s/terminate" + consumerPrepareURL = "http://localhost:%d/dataflows/prepare" providerCallbackURL = "http://provider.com/dp/callback" contentType = "Content-Type" jsonContentType = "application/json" diff --git a/examples/streaming-pull-dataplane/launcher/services.go b/examples/streaming-pull-dataplane/launcher/services.go index f7a3d91..ae9873f 100644 --- a/examples/streaming-pull-dataplane/launcher/services.go +++ b/examples/streaming-pull-dataplane/launcher/services.go @@ -13,10 +13,11 @@ package launcher import ( + "log" + "github.com/metaform/dataplane-sdk-go/examples/natsservices" "github.com/metaform/dataplane-sdk-go/examples/streaming-pull-dataplane/consumer" "github.com/metaform/dataplane-sdk-go/examples/streaming-pull-dataplane/provider" - "log" ) func LaunchServices() (*provider.ProviderDataPlane, *consumer.ConsumerDataPlane) { diff --git a/examples/streaming/terminate.go b/examples/streaming/terminate.go index b252378..2a97c82 100644 --- a/examples/streaming/terminate.go +++ b/examples/streaming/terminate.go @@ -14,10 +14,11 @@ package streaming import ( "context" - "github.com/google/uuid" - "github.com/metaform/dataplane-sdk-go/examples/controlplane" "log" "time" + + "github.com/google/uuid" + "github.com/metaform/dataplane-sdk-go/examples/controlplane" ) // TerminateScenario coordinates a simulated data transfer scenario and forcibly terminates it after a predefined duration. diff --git a/examples/sync-pull-dataplane/launcher/services.go b/examples/sync-pull-dataplane/launcher/services.go index b07b123..d690d28 100644 --- a/examples/sync-pull-dataplane/launcher/services.go +++ b/examples/sync-pull-dataplane/launcher/services.go @@ -14,10 +14,11 @@ package launcher import ( "context" - "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/consumer" - "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/provider" "sync" "time" + + "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/consumer" + "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/provider" ) func LaunchServicesAndWait(wg *sync.WaitGroup) { diff --git a/examples/sync-pull-dataplane/main.go b/examples/sync-pull-dataplane/main.go index 99ac607..3d4f436 100644 --- a/examples/sync-pull-dataplane/main.go +++ b/examples/sync-pull-dataplane/main.go @@ -14,12 +14,13 @@ package main import ( "context" + "log" + "sync" + "github.com/google/uuid" "github.com/metaform/dataplane-sdk-go/examples/controlplane" "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/consumer" "github.com/metaform/dataplane-sdk-go/examples/sync-pull-dataplane/launcher" - "log" - "sync" ) // Demonstrates initiating a data transfer using a provider data plane that implements synchronous signaling start operations. diff --git a/go.mod b/go.mod index eac8a29..aaa34a1 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/metaform/dataplane-sdk-go go 1.24.1 require ( + github.com/go-chi/chi/v5 v5.2.3 github.com/go-playground/validator/v10 v10.27.0 github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 diff --git a/go.sum b/go.sum index a953675..c0f09aa 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= diff --git a/internal/tests/api_integration_test.go b/internal/tests/api_integration_test.go index e9b4dcb..2339e2f 100644 --- a/internal/tests/api_integration_test.go +++ b/internal/tests/api_integration_test.go @@ -11,8 +11,10 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" + "github.com/go-chi/chi/v5" "github.com/google/uuid" _ "github.com/lib/pq" "github.com/metaform/dataplane-sdk-go/pkg/dsdk" @@ -30,14 +32,27 @@ var database *sql.DB func newServerWithSdk(t *testing.T, sdk *dsdk.DataPlaneSDK) http.Handler { t.Helper() sdkApi := dsdk.NewDataPlaneApi(sdk) - mux := http.NewServeMux() - - mux.HandleFunc("/start", sdkApi.Start) - mux.HandleFunc("/prepare", sdkApi.Prepare) - mux.HandleFunc("/terminate/", sdkApi.Terminate) - mux.HandleFunc("/suspend/", sdkApi.Suspend) - mux.HandleFunc("/status", sdkApi.Status) - return mux + r := chi.NewRouter() + + r.Post("/dataflows/start", sdkApi.Start) + r.Post("/dataflows/{id}/start", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.StartById(writer, request, id) + }) + r.Post("/dataflows/prepare", sdkApi.Prepare) + r.Post("/dataflows/{id}/terminate", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Terminate(id, writer, request) + }) + r.Post("/dataflows/{id}/suspend", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Suspend(id, writer, request) + }) + r.Get("/dataflows/{id}/status", func(writer http.ResponseWriter, request *http.Request) { + id := chi.URLParam(request, "id") + sdkApi.Status(id, writer, request) + }) + return r } var handler http.Handler @@ -47,7 +62,7 @@ func TestMain(m *testing.M) { database = db t := &testing.T{} - sdk, err := createSdk(db) + sdk, err := newSdk(db) assert.NoError(t, err) handler = newServerWithSdk(t, sdk) code := m.Run() @@ -57,12 +72,12 @@ func TestMain(m *testing.M) { } // E2E tests -func Test_Start_NotExists(t *testing.T) { +func Test_Start_NotYetExists(t *testing.T) { payload, err := serialize(newStartMessage()) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/start", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, "/dataflows/start", bytes.NewBuffer(payload)) assert.NoError(t, err) rr := httptest.NewRecorder() @@ -81,7 +96,7 @@ func Test_Start_InvalidPayload(t *testing.T) { sm.CounterPartyID = "" // should raise a validation error payload, err := serialize(sm) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/start", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, "/dataflows/start", bytes.NewBuffer(payload)) assert.NoError(t, err) rr := httptest.NewRecorder() @@ -90,11 +105,114 @@ func Test_Start_InvalidPayload(t *testing.T) { assert.NotNil(t, rr.Body.String()) } +func Test_StartByID_WhenNotFound(t *testing.T) { + id := uuid.New().String() + + requestBody, err := serialize(newStartByIdMessage()) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+id+"/start", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) + +} + +func Test_StartByID_WhenStartedOrStarting(t *testing.T) { + + states := []dsdk.DataFlowState{ + dsdk.Started, + dsdk.Starting, + } + + for _, state := range states { + id := uuid.New().String() + store := postgres.NewStore(database) + flow, err := newFlowBuilder().ID(id).State(state).Build() + assert.NoError(t, err) + assert.NoError(t, store.Create(ctx, flow)) + + requestBody, err := serialize(newStartByIdMessage()) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+id+"/start", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + found, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, dsdk.Started, found.State) + } + +} + +func Test_StartByID_WhenPrepared(t *testing.T) { + + tests := []struct { + isConsumer bool + expectedHttpCode int + expectedState dsdk.DataFlowState + }{ + { + isConsumer: true, + expectedHttpCode: http.StatusOK, + expectedState: dsdk.Started, + }, + { + isConsumer: false, + expectedHttpCode: http.StatusBadRequest, + expectedState: dsdk.Prepared, + }, + } + + for _, test := range tests { + id := uuid.New().String() + store := postgres.NewStore(database) + flow, err := newFlowBuilder().ID(id).State(dsdk.Prepared).Consumer(test.isConsumer).Build() + assert.NoError(t, err) + assert.NoError(t, store.Create(ctx, flow)) + + requestBody, err := serialize(newStartByIdMessage()) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+id+"/start", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, test.expectedHttpCode, rr.Code) + found, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, test.expectedState, found.State) + } + +} + +func Test_StartByID_MissingSourceAddress(t *testing.T) { + requestBody, err := serialize(dsdk.DataFlowStartByIdMessage{}) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/some-id/start", bytes.NewBuffer(requestBody)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) +} + func Test_Prepare(t *testing.T) { payload, err := serialize(newPrepareMessage()) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/prepare", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, "/dataflows/prepare", bytes.NewBuffer(payload)) assert.NoError(t, err) rr := httptest.NewRecorder() @@ -120,7 +238,7 @@ func Test_Prepare_WrongState(t *testing.T) { payload, err := serialize(message) assert.NoError(t, err) - req, err := http.NewRequest(http.MethodPost, "/prepare", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, "/dataflows/prepare", bytes.NewBuffer(payload)) assert.NoError(t, err) rr := httptest.NewRecorder() @@ -129,6 +247,153 @@ func Test_Prepare_WrongState(t *testing.T) { assert.Equal(t, http.StatusConflict, rr.Code) } +func Test_Suspend_Success(t *testing.T) { + + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + store := postgres.NewStore(database) + err = store.Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/suspend", strings.NewReader("")) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + byId, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, dsdk.Suspended, byId.State) +} + +func Test_Suspend_WithReason(t *testing.T) { + + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + store := postgres.NewStore(database) + err = store.Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/suspend", strings.NewReader(`{"reason": "test reason"}`)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + byId, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, dsdk.Suspended, byId.State) + assert.Equal(t, "test reason", byId.ErrorDetail) +} + +func Test_Suspend_WhenNotExists(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + //missing: storing the flow + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/suspend", strings.NewReader("")) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func Test_Suspend_WhenNotStarted(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Completed).Build() // completed flows cannot transition to suspended + assert.NoError(t, err) + err = postgres.NewStore(database).Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/suspend", strings.NewReader("")) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) +} + +func Test_Terminate_Success(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + store := postgres.NewStore(database) + err = store.Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/terminate", strings.NewReader("")) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + byId, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, dsdk.Terminated, byId.State) + +} + +func Test_Terminate_WithReason(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + store := postgres.NewStore(database) + err = store.Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/terminate", strings.NewReader(`{"reason": "test reason"}`)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + byId, err := store.FindById(ctx, id) + assert.NoError(t, err) + assert.Equal(t, dsdk.Terminated, byId.State) + assert.Equal(t, "test reason", byId.ErrorDetail) +} + +func Test_Terminate_WhenNotFound(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/dataflows/"+flow.ID+"/terminate", strings.NewReader("")) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func Test_GetStatus(t *testing.T) { + id := uuid.New().String() + flow, err := newFlowBuilder().ID(id).State(dsdk.Started).Build() + assert.NoError(t, err) + store := postgres.NewStore(database) + err = store.Create(ctx, flow) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/dataflows/"+flow.ID+"/status", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseMessage dsdk.DataFlowStatusResponseMessage + err = json.NewDecoder(rr.Body).Decode(&responseMessage) + assert.NoError(t, err) + assert.Equal(t, responseMessage.State, dsdk.Started) +} + +func Test_GetStatus_NotFound(t *testing.T) { + + req, err := http.NewRequest(http.MethodGet, "/dataflows/not-exist/status", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) +} + func newFlowBuilder() *dsdk.DataFlowBuilder { bldr := &dsdk.DataFlowBuilder{} return bldr.ID("test-id"). @@ -158,7 +423,21 @@ func newStartMessage() dsdk.DataFlowStartMessage { TransferType: newTransferType(), DestinationDataAddress: dsdk.DataAddress{}, }, - SourceDataAddress: &dsdk.DataAddress{}, + SourceDataAddress: &dsdk.DataAddress{ + Properties: map[string]any{ + "foo": "bar", + }, + }, + } +} + +func newStartByIdMessage() dsdk.DataFlowStartByIdMessage { + return dsdk.DataFlowStartByIdMessage{ + SourceDataAddress: &dsdk.DataAddress{ + Properties: map[string]any{ + "foo": "bar", + }, + }, } } @@ -190,7 +469,7 @@ func newCallback() dsdk.CallbackURL { return dsdk.CallbackURL{Scheme: "http", Host: "test.com", Path: "/callback"} } -func createSdk(db *sql.DB) (*dsdk.DataPlaneSDK, error) { +func newSdk(db *sql.DB) (*dsdk.DataPlaneSDK, error) { sdk, err := dsdk.NewDataPlaneSDKBuilder(). Store(postgres.NewStore(db)). TransactionContext(postgres.NewDBTransactionContext(db)). diff --git a/pkg/dsdk/api.go b/pkg/dsdk/api.go index 1a98a15..7eaa612 100644 --- a/pkg/dsdk/api.go +++ b/pkg/dsdk/api.go @@ -13,12 +13,12 @@ package dsdk import ( + "bytes" "encoding/json" "errors" "fmt" + "io" "net/http" - "net/url" - "strings" "github.com/google/uuid" ) @@ -47,12 +47,12 @@ func (d *DataPlaneApi) Prepare(w http.ResponseWriter, r *http.Request) { } if err := prepareMessage.Validate(); err != nil { - d.validationError(err, w) + d.handleError(err, w) } response, err := d.sdk.Prepare(r.Context(), prepareMessage) if err != nil { - d.otherError(err, w) + d.handleError(err, w) return } @@ -78,13 +78,13 @@ func (d *DataPlaneApi) Start(w http.ResponseWriter, r *http.Request) { } if err := startMessage.Validate(); err != nil { - d.validationError(err, w) + d.handleError(err, w) return } response, err := d.sdk.Start(r.Context(), startMessage) if err != nil { - d.otherError(err, w) + d.handleError(err, w) return } @@ -93,94 +93,129 @@ func (d *DataPlaneApi) Start(w http.ResponseWriter, r *http.Request) { code = http.StatusOK } else { code = http.StatusAccepted + w.Header().Set("Location", "/dataflows/"+startMessage.ProcessID) } d.writeResponse(w, code, response) } -func (d *DataPlaneApi) Terminate(w http.ResponseWriter, r *http.Request) { - var terminateMessage DataFlowTransitionMessage +func (d *DataPlaneApi) StartById(w http.ResponseWriter, r *http.Request, id string) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusBadRequest) + return + } + var startMessage DataFlowStartByIdMessage - if err := json.NewDecoder(r.Body).Decode(&terminateMessage); err != nil { + if err := json.NewDecoder(r.Body).Decode(&startMessage); err != nil { d.decodingError(w, err) return } - if err := terminateMessage.Validate(); err != nil { - d.validationError(err, w) + + if err := startMessage.Validate(); err != nil { + d.handleError(err, w) return } - d.transition(w, r, func(processID string) error { - //todo: pass Reason to Terminate - return d.sdk.Terminate(r.Context(), processID) - }) -} -func (d *DataPlaneApi) Suspend(w http.ResponseWriter, r *http.Request) { - var suspendMessage DataFlowTransitionMessage - - if err := json.NewDecoder(r.Body).Decode(&suspendMessage); err != nil { - d.decodingError(w, err) + response, err := d.sdk.StartById(r.Context(), id, startMessage) + if err != nil { + d.handleError(err, w) return } - if err := suspendMessage.Validate(); err != nil { - d.validationError(err, w) - return + + var code int + if response.State == Started { + code = http.StatusOK + } else { + code = http.StatusAccepted + w.Header().Set("Location", "/dataflows/"+id) } - d.transition(w, r, func(processID string) error { - //todo: pass Reason to Suspend - return d.sdk.Suspend(r.Context(), processID) - }) + d.writeResponse(w, code, response) } -func (d *DataPlaneApi) Status(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid request method", http.StatusBadRequest) - return - } - processID, err := ParseIDFromURL(r.URL) +func (d *DataPlaneApi) Terminate(id string, w http.ResponseWriter, r *http.Request) { + reason := "" + // Peek into the body + bodyBytes, err := io.ReadAll(r.Body) if err != nil { + d.decodingError(w, err) return } - dataFlow, err := d.sdk.Status(r.Context(), processID) - if err != nil { - d.otherError(err, w) + // if a body was sent, parse it, read the reason + + if len(bodyBytes) > 0 { + var terminateMessage DataFlowTransitionMessage + + if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&terminateMessage); err != nil { + d.decodingError(w, err) + return + } + if err := terminateMessage.Validate(); err != nil { + d.handleError(err, w) + return + } + reason = terminateMessage.Reason + } + terminateError := d.sdk.Terminate(r.Context(), id, reason) + if terminateError != nil { + d.handleError(terminateError, w) return } + w.Header().Set(contentType, jsonContentType) - response := DataFlowStatusResponseMessage{ - State: dataFlow.State, - DataFlowID: dataFlow.ID, - } - d.writeResponse(w, 200, response) + w.WriteHeader(http.StatusOK) } -func (d *DataPlaneApi) transition(w http.ResponseWriter, r *http.Request, transition func(processID string) error) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusBadRequest) - return - } +func (d *DataPlaneApi) Suspend(id string, w http.ResponseWriter, r *http.Request) { - processID, err := ParseIDFromURL(r.URL) + reason := "" + // Peek into the body + bodyBytes, err := io.ReadAll(r.Body) if err != nil { - d.otherError(err, w) + d.decodingError(w, err) return } + // if a body was sent, parse it, read the reason + if len(bodyBytes) > 0 { + var suspendMessage DataFlowTransitionMessage - var terminateMessage DataFlowTransitionMessage + if err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(&suspendMessage); err != nil { + d.decodingError(w, err) + return + } + if err := suspendMessage.Validate(); err != nil { + d.handleError(err, w) + return + } + reason = suspendMessage.Reason + } - if err := json.NewDecoder(r.Body).Decode(&terminateMessage); err != nil { - d.decodingError(w, err) + suspensionError := d.sdk.Suspend(r.Context(), id, reason) + if suspensionError != nil { + d.handleError(suspensionError, w) return } - err = transition(processID) + w.Header().Set(contentType, jsonContentType) + w.WriteHeader(http.StatusOK) + +} + +func (d *DataPlaneApi) Status(processID string, w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid request method", http.StatusBadRequest) + return + } + dataFlow, err := d.sdk.Status(r.Context(), processID) if err != nil { - d.otherError(err, w) + d.handleError(err, w) return } - w.Header().Set(contentType, jsonContentType) - w.WriteHeader(http.StatusOK) + response := DataFlowStatusResponseMessage{ + State: dataFlow.State, + DataFlowID: dataFlow.ID, + } + d.writeResponse(w, http.StatusOK, response) } func (d *DataPlaneApi) decodingError(w http.ResponseWriter, err error) { @@ -189,10 +224,14 @@ func (d *DataPlaneApi) decodingError(w http.ResponseWriter, err error) { d.writeResponse(w, http.StatusBadRequest, &DataFlowResponseMessage{Error: fmt.Sprintf("Failed to decode request body [%s]", id)}) } -// otherError writes an error message to the HTTP response that indicates "any other" error, such as 409, 500, etc. -func (d *DataPlaneApi) otherError(err error, w http.ResponseWriter) { +// handleError writes an error message to the HTTP response that indicates "any other" error, such as 409, 500, etc. +func (d *DataPlaneApi) handleError(err error, w http.ResponseWriter) { switch { + case errors.Is(err, ErrValidation), errors.Is(err, ErrInvalidTransition): + d.badRequest(err.Error(), w) + case errors.Is(err, ErrNotFound): + d.writeResponse(w, http.StatusNotFound, &DataFlowResponseMessage{Error: err.Error()}) case errors.Is(err, ErrConflict): message := fmt.Sprintf("%s", err) d.writeResponse(w, http.StatusConflict, &DataFlowResponseMessage{Error: message}) @@ -202,13 +241,9 @@ func (d *DataPlaneApi) otherError(err error, w http.ResponseWriter) { d.writeResponse(w, http.StatusInternalServerError, &DataFlowResponseMessage{Error: message}) } } -func (d *DataPlaneApi) validationError(err error, w http.ResponseWriter) { - if errors.Is(err, ErrValidation) { - message := fmt.Sprintf("Validation error: %s", err) - d.writeResponse(w, http.StatusBadRequest, &DataFlowResponseMessage{Error: message}) - } else { - d.otherError(err, w) - } + +func (d *DataPlaneApi) badRequest(errMsg string, w http.ResponseWriter) { + d.writeResponse(w, http.StatusBadRequest, &DataFlowResponseMessage{Error: errMsg}) } func (d *DataPlaneApi) writeResponse(w http.ResponseWriter, code int, response any) { @@ -222,29 +257,3 @@ func (d *DataPlaneApi) writeResponse(w http.ResponseWriter, code int, response a return } } - -func ParseIDFromURL(u *url.URL) (string, error) { - if u == nil { - return "", errors.New("URL cannot be nil") - } - - path := u.Path - if path == "" { - return "", errors.New("URL path is empty") - } - - // Remove trailing slash if present - path = strings.TrimSuffix(path, "/") - - // Split the path by '/' to get path segments - pathParts := strings.Split(path, "/") - - // Find the last non-empty segment - for i := len(pathParts) - 1; i >= 0; i-- { - if pathParts[i] != "" { - return pathParts[i], nil - } - } - - return "", errors.New("no valid ID found in URL path") -} diff --git a/pkg/dsdk/api_test.go b/pkg/dsdk/api_test.go deleted file mode 100644 index 135345a..0000000 --- a/pkg/dsdk/api_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) 2025 Metaform Systems, Inc -// -// This program and the accompanying materials are made available under the -// terms of the Apache License, Version 2.0 which is available at -// https://www.apache.org/licenses/LICENSE-2.0 -// -// SPDX-License-Identifier: Apache-2.0 -// -// Contributors: -// Metaform Systems, Inc. - initial API and implementation -// - -package dsdk - -import ( - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseIDFromURL(t *testing.T) { - tests := []struct { - name string - urlString string - expectedID string - expectError bool - }{ - { - name: "valid URL with single path segment", - urlString: "http://example.com/12345", - expectedID: "12345", - expectError: false, - }, - { - name: "valid URL with multiple path segments", - urlString: "http://example.com/api/v1/flows/12345", - expectedID: "12345", - expectError: false, - }, - { - name: "valid URL with trailing slash", - urlString: "http://example.com/api/flows/12345/", - expectedID: "12345", - expectError: false, - }, - { - name: "valid URL with UUID", - urlString: "http://example.com/flows/123e4567-e89b-12d3-a456-426614174000", - expectedID: "123e4567-e89b-12d3-a456-426614174000", - expectError: false, - }, - { - name: "path with root slash only", - urlString: "http://example.com/", - expectedID: "", - expectError: true, - }, - { - name: "empty path", - urlString: "http://example.com", - expectedID: "", - expectError: true, - }, - { - name: "path with only slashes", - urlString: "http://example.com///", - expectedID: "", - expectError: true, - }, - { - name: "path with empty segments in middle", - urlString: "http://example.com/api//flows//12345", - expectedID: "12345", - expectError: false, - }, - { - name: "relative path", - urlString: "/api/flows/12345", - expectedID: "12345", - expectError: false, - }, - { - name: "just the ID", - urlString: "/12345", - expectedID: "12345", - expectError: false, - }, - { - name: "complex path with query parameters", - urlString: "http://example.com/api/flows/12345?param=value", - expectedID: "12345", - expectError: false, - }, - { - name: "path with fragment", - urlString: "http://example.com/api/flows/12345#section", - expectedID: "12345", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - parsedURL, err := url.Parse(tt.urlString) - require.NoError(t, err, "Failed to parse test URL") - - id, err := ParseIDFromURL(parsedURL) - - if tt.expectError { - assert.Error(t, err) - assert.Empty(t, id) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expectedID, id) - } - }) - } -} - -func TestParseIDFromURL_NilURL(t *testing.T) { - id, err := ParseIDFromURL(nil) - - assert.Error(t, err) - assert.Empty(t, id) - assert.Contains(t, err.Error(), "URL cannot be nil") -} - -func TestParseIDFromURL_EdgeCases(t *testing.T) { - t.Run("URL with port", func(t *testing.T) { - u, err := url.Parse("http://example.com:8080/api/flows/12345") - require.NoError(t, err) - - id, err := ParseIDFromURL(u) - assert.NoError(t, err) - assert.Equal(t, "12345", id) - }) - - t.Run("URL with special characters in ID", func(t *testing.T) { - u, err := url.Parse("http://example.com/api/flows/test-id_123") - require.NoError(t, err) - - id, err := ParseIDFromURL(u) - assert.NoError(t, err) - assert.Equal(t, "test-id_123", id) - }) - - t.Run("very long path", func(t *testing.T) { - longPath := "/api/v1/dataplane/flows/processes/instances/executions/12345" - u, err := url.Parse("http://example.com" + longPath) - require.NoError(t, err) - - id, err := ParseIDFromURL(u) - assert.NoError(t, err) - assert.Equal(t, "12345", id) - }) -} diff --git a/pkg/dsdk/dsdk.go b/pkg/dsdk/dsdk.go index 9984c99..0ee2398 100644 --- a/pkg/dsdk/dsdk.go +++ b/pkg/dsdk/dsdk.go @@ -121,41 +121,7 @@ func (dsdk *DataPlaneSDK) Start(ctx context.Context, message DataFlowStartMessag return fmt.Errorf("performing de-duplication for %s: %w", processID, err) } - switch { - case flow != nil && (flow.State == Starting || flow.State == Started): - // duplicate message, pass to handler to generate a data address if needed - response, err = dsdk.onStart(ctx, flow, dsdk, &ProcessorOptions{Duplicate: true, SourceDataAddress: message.SourceDataAddress}) - if err != nil { - return fmt.Errorf("processing data flow: %w", err) - } - - err = dsdk.startState(response, flow) - if err != nil { - return fmt.Errorf("onStart returned an invalid state: %w", err) - } - - if err := dsdk.Store.Save(ctx, flow); err != nil { - return fmt.Errorf("creating data flow: %w", err) - } - return nil - case flow != nil && flow.Consumer && flow.State == Prepared: - // consumer side, process - response, err = dsdk.onStart(ctx, flow, dsdk, &ProcessorOptions{SourceDataAddress: message.SourceDataAddress}) - if err != nil { - return fmt.Errorf("processing data flow: %w", err) - } - - err = dsdk.startState(response, flow) - if err != nil { - return fmt.Errorf("onStart returned an invalid state: %w", err) - } - - if err := dsdk.Store.Save(ctx, flow); err != nil { - return fmt.Errorf("updating data flow: %w", err) - } - - return nil - case flow == nil: + if flow == nil { // provider side, process flow, err = NewDataFlowBuilder().ID(processID). State(Starting). @@ -184,16 +150,38 @@ func (dsdk *DataPlaneSDK) Start(ctx context.Context, message DataFlowStartMessag return fmt.Errorf("creating data flow: %w", err) } return nil - default: - return fmt.Errorf("data flow %s is not in STARTED state: %s", flow.ID, flow.State) } + + response, err = dsdk.startExistingFlow(ctx, flow, message.SourceDataAddress) + return err }) return response, err } -func (dsdk *DataPlaneSDK) Terminate(ctx context.Context, processID string) error { +func (dsdk *DataPlaneSDK) StartById(ctx context.Context, processID string, message DataFlowStartByIdMessage) (*DataFlowResponseMessage, error) { + var response *DataFlowResponseMessage + + err := dsdk.execute(ctx, func(ctx context.Context) error { + existingFlow, err := dsdk.Store.FindById(ctx, processID) + if err != nil && !errors.Is(err, ErrNotFound) { + return fmt.Errorf("performing de-duplication for %s: %w", processID, err) + } + + if existingFlow == nil { // this should never happen -> the store would return an error + return ErrNotFound + } + + response, err = dsdk.startExistingFlow(ctx, existingFlow, message.SourceDataAddress) + return err + + }) + return response, err + +} + +func (dsdk *DataPlaneSDK) Terminate(ctx context.Context, processID string, reason string) error { if processID == "" { return errors.New("processID cannot be empty") } @@ -212,7 +200,7 @@ func (dsdk *DataPlaneSDK) Terminate(ctx context.Context, processID string) error return fmt.Errorf("terminating data flow %s: %w", flow.ID, err) } - err = flow.TransitionToTerminated() + err = flow.TransitionToTerminated(reason) if err != nil { return err } @@ -225,7 +213,7 @@ func (dsdk *DataPlaneSDK) Terminate(ctx context.Context, processID string) error }) } -func (dsdk *DataPlaneSDK) Suspend(ctx context.Context, processID string) error { +func (dsdk *DataPlaneSDK) Suspend(ctx context.Context, processID string, reason string) error { if processID == "" { return errors.New("processID cannot be empty") } @@ -243,7 +231,7 @@ func (dsdk *DataPlaneSDK) Suspend(ctx context.Context, processID string) error { if err := dsdk.onSuspend(ctx, flow); err != nil { return fmt.Errorf("suspending data flow %s: %w", flow.ID, err) } - err = flow.TransitionToSuspended() + err = flow.TransitionToSuspended(reason) if err != nil { return err } @@ -270,6 +258,47 @@ func (dsdk *DataPlaneSDK) Status(ctx context.Context, id string) (*DataFlow, err return flow, err } +func (dsdk *DataPlaneSDK) startExistingFlow(ctx context.Context, flow *DataFlow, sourceAddress *DataAddress) (*DataFlowResponseMessage, error) { + switch { + case flow != nil && (flow.State == Starting || flow.State == Started): + // duplicate message, pass to handler to generate a data address if needed + response, err := dsdk.onStart(ctx, flow, dsdk, &ProcessorOptions{Duplicate: true, SourceDataAddress: sourceAddress}) + if err != nil { + return nil, fmt.Errorf("processing data flow: %w", err) + } + + err = dsdk.startState(response, flow) + if err != nil { + return nil, fmt.Errorf("onStart returned an invalid state: %w", err) + } + + if err := dsdk.Store.Save(ctx, flow); err != nil { + return nil, fmt.Errorf("creating data flow: %w", err) + } + return response, err + case flow != nil && flow.Consumer && flow.State == Prepared: + // consumer side, process + response, err := dsdk.onStart(ctx, flow, dsdk, &ProcessorOptions{SourceDataAddress: sourceAddress}) + if err != nil { + return nil, fmt.Errorf("processing data flow: %w", err) + } + + err = dsdk.startState(response, flow) + if err != nil { + return nil, fmt.Errorf("onStart returned an invalid state: %w", err) + } + + if err := dsdk.Store.Save(ctx, flow); err != nil { + return nil, fmt.Errorf("updating data flow: %w", err) + } + + return response, nil + + default: + return nil, fmt.Errorf("%w: data flow %s is not in STARTED state: %s", ErrInvalidTransition, flow.ID, flow.State) + } +} + func (dsdk *DataPlaneSDK) startState(response *DataFlowResponseMessage, flow *DataFlow) error { if response.State == Started { err := flow.TransitionToStarted() diff --git a/pkg/dsdk/dsdk_test.go b/pkg/dsdk/dsdk_test.go index 440cc64..d5128dc 100644 --- a/pkg/dsdk/dsdk_test.go +++ b/pkg/dsdk/dsdk_test.go @@ -406,7 +406,7 @@ func Test_DataPlaneSDK_Terminate(t *testing.T) { return df.State == Terminated })).Return(nil) - err := dsdk.Terminate(ctx, "flow123") + err := dsdk.Terminate(ctx, "flow123", "") assert.NoError(t, err) } @@ -424,7 +424,7 @@ func Test_DataPlaneSDK_Terminate_NotFound(t *testing.T) { ctx := context.Background() store.EXPECT().FindById(ctx, "flow123").Return(nil, ErrNotFound) - err := dsdk.Terminate(ctx, "flow123") + err := dsdk.Terminate(ctx, "flow123", "") assert.ErrorContains(t, err, "not found") } @@ -449,7 +449,7 @@ func Test_DataPlaneSDK_Terminate_AlreadyTerminated(t *testing.T) { // no transition and no save call expected - err := dsdk.Terminate(ctx, "flow123") + err := dsdk.Terminate(ctx, "flow123", "") assert.NoError(t, err) } @@ -471,7 +471,7 @@ func Test_DataPlaneSDK_Terminate_SdkCallbackError(t *testing.T) { State: Started, }, nil) - err := dsdk.Terminate(ctx, "flow123") + err := dsdk.Terminate(ctx, "flow123", "") assert.ErrorContains(t, err, "some error") } @@ -497,7 +497,7 @@ func Test_DataPlaneSDK_Suspend(t *testing.T) { return df.State == Suspended })).Return(nil) - err := dsdk.Suspend(ctx, "flow123") + err := dsdk.Suspend(ctx, "flow123", "") assert.NoError(t, err) } @@ -515,7 +515,7 @@ func Test_DataPlaneSDK_Suspend_NotFound(t *testing.T) { ctx := context.Background() store.EXPECT().FindById(ctx, "flow123").Return(nil, ErrNotFound) - err := dsdk.Suspend(ctx, "flow123") + err := dsdk.Suspend(ctx, "flow123", "") assert.ErrorContains(t, err, "not found") } @@ -540,7 +540,7 @@ func Test_DataPlaneSDK_Suspend_AlreadySuspended(t *testing.T) { // no transition and no save call expected - err := dsdk.Suspend(ctx, "flow123") + err := dsdk.Suspend(ctx, "flow123", "") assert.NoError(t, err) } @@ -562,7 +562,7 @@ func Test_DataPlaneSDK_Suspend_SdkCallbackError(t *testing.T) { State: Started, }, nil) - err := dsdk.Suspend(ctx, "flow123") + err := dsdk.Suspend(ctx, "flow123", "") assert.ErrorContains(t, err, "some error") } diff --git a/pkg/dsdk/errors.go b/pkg/dsdk/errors.go index ccac17c..902b0a6 100644 --- a/pkg/dsdk/errors.go +++ b/pkg/dsdk/errors.go @@ -15,6 +15,8 @@ var ( ErrNotFound = errors.New("not found") // ErrInvalidInput Sentinel error to indicate a wrong input, e.g. a string when a number was expected, or an empty string ErrInvalidInput = errors.New("invalid input") + // ErrInvalidTransition Sentinel error to indicate an invalid state transition, e.g. of a data flow + ErrInvalidTransition = errors.New("invalid transition") ) // NewValidationError Helper to create new ValidationError diff --git a/pkg/dsdk/messages.go b/pkg/dsdk/messages.go index c5acf95..6115542 100644 --- a/pkg/dsdk/messages.go +++ b/pkg/dsdk/messages.go @@ -13,7 +13,7 @@ type DataFlowBaseMessage struct { DataspaceContext string `json:"dataspaceContext" validate:"required"` ProcessID string `json:"processID" validate:"required"` AgreementID string `json:"agreementID" validate:"required"` - DatasetID string `json:"datasetID" validate:"required"` + DatasetID string `json:"datasetID"` CallbackAddress CallbackURL `json:"callbackAddress" validate:"required,callback-url"` TransferType TransferType `json:"transferType" validate:"required"` DestinationDataAddress DataAddress `json:"destinationDataAddress" validate:"required"` @@ -36,7 +36,7 @@ func (d *DataFlowBaseMessage) Validate() error { type DataFlowStartMessage struct { DataFlowBaseMessage - SourceDataAddress *DataAddress `json:"sourceDataAddress,omitempty" validate:"required"` + SourceDataAddress *DataAddress `json:"sourceDataAddress,omitempty"` } func (d *DataFlowStartMessage) Validate() error { @@ -51,6 +51,18 @@ func (d *DataFlowStartMessage) Validate() error { return nil } +type DataFlowStartByIdMessage struct { + SourceDataAddress *DataAddress `json:"sourceDataAddress,omitempty" validate:"required"` +} + +func (d *DataFlowStartByIdMessage) Validate() error { + err := v.Struct(d) + if err != nil { + return WrapValidationError(err) + } + return nil +} + type DataFlowPrepareMessage struct { DataFlowBaseMessage } diff --git a/pkg/dsdk/messages_test.go b/pkg/dsdk/messages_test.go index 9d9a469..51f0a5a 100644 --- a/pkg/dsdk/messages_test.go +++ b/pkg/dsdk/messages_test.go @@ -55,7 +55,7 @@ func Test_StartMessage_MissingProperties(t *testing.T) { func Test_StartMessage_MissingSourceDataAddress(t *testing.T) { startMsg := DataFlowStartMessage{DataFlowBaseMessage: newBaseMessage()} - assert.ErrorIs(t, startMsg.Validate(), ErrValidation) + assert.NoError(t, startMsg.Validate()) } func Test_PrepareMessage_Success(t *testing.T) { diff --git a/pkg/dsdk/model.go b/pkg/dsdk/model.go index 05a389b..7723e7c 100644 --- a/pkg/dsdk/model.go +++ b/pkg/dsdk/model.go @@ -143,7 +143,7 @@ func (df *DataFlow) TransitionToPreparing() error { return nil } if df.State != Uninitialized { - return fmt.Errorf("invalid transition: cannot transition from %v to PREPARING", df.State) + return fmt.Errorf("%w: cannot transition from %v to PREPARING", ErrInvalidTransition, df.State) } df.State = Preparing df.StateTimestamp = time.Now().UnixMilli() @@ -156,7 +156,7 @@ func (df *DataFlow) TransitionToPrepared() error { return nil } if df.State != Uninitialized && df.State != Preparing { - return fmt.Errorf("invalid transition: cannot transition from %v to PREPARED", df.State) + return fmt.Errorf("%w: cannot transition from %v to PREPARED", ErrInvalidTransition, df.State) } df.State = Prepared df.StateTimestamp = time.Now().UnixMilli() @@ -169,7 +169,7 @@ func (df *DataFlow) TransitionToStarting() error { return nil } if df.State != Uninitialized && df.State != Prepared { - return fmt.Errorf("invalid transition: cannot transition from %v to STARTING", df.State) + return fmt.Errorf("%w: cannot transition from %v to STARTING", ErrInvalidTransition, df.State) } df.State = Starting df.StateTimestamp = time.Now().UnixMilli() @@ -182,7 +182,7 @@ func (df *DataFlow) TransitionToStarted() error { return nil } if df.State != Uninitialized && df.State != Prepared && df.State != Starting && df.State != Suspended { - return fmt.Errorf("invalid transition: cannot transition from %v to STARTED", df.State) + return fmt.Errorf("%w: cannot transition from %v to STARTED", ErrInvalidTransition, df.State) } df.State = Started df.StateTimestamp = time.Now().UnixMilli() @@ -190,14 +190,16 @@ func (df *DataFlow) TransitionToStarted() error { return nil } -func (df *DataFlow) TransitionToSuspended() error { +func (df *DataFlow) TransitionToSuspended(reason string) error { if df.State == Suspended { return nil } if df.State != Started { - return fmt.Errorf("invalid transition: cannot transition from %v to SUSPENDED", df.State) + return fmt.Errorf("%w: cannot transition from %v to SUSPENDED", ErrInvalidTransition, df.State) } df.State = Suspended + //todo: what to do with the reason string? + df.ErrorDetail = reason df.StateTimestamp = time.Now().UnixMilli() df.StateCount++ return nil @@ -208,7 +210,7 @@ func (df *DataFlow) TransitionToCompleted() error { return nil } if df.State != Started { - return fmt.Errorf("invalid transition: cannot transition from %v to COMPLETED", df.State) + return fmt.Errorf("%w: cannot transition from %v to COMPLETED", ErrInvalidTransition, df.State) } df.State = Completed df.StateTimestamp = time.Now().UnixMilli() @@ -216,12 +218,13 @@ func (df *DataFlow) TransitionToCompleted() error { return nil } -func (df *DataFlow) TransitionToTerminated() error { +func (df *DataFlow) TransitionToTerminated(reason string) error { if df.State == Terminated { return nil // todo: does returning an error make sense here? } // Any state can transition to terminated df.State = Terminated + df.ErrorDetail = reason df.StateTimestamp = time.Now().UnixMilli() df.StateCount++ return nil diff --git a/pkg/dsdk/model_transition_test.go b/pkg/dsdk/model_transition_test.go index 4cbc7f7..71c5b0c 100644 --- a/pkg/dsdk/model_transition_test.go +++ b/pkg/dsdk/model_transition_test.go @@ -479,7 +479,7 @@ func TestDataFlow_transitionToSuspended(t *testing.T) { initialStateCount := df.StateCount initialTimestamp := df.StateTimestamp - err := df.TransitionToSuspended() + err := df.TransitionToSuspended("test-reason") if tc.expectErr { if err == nil { @@ -678,7 +678,7 @@ func TestDataFlow_transitionToTerminated(t *testing.T) { initialStateCount := df.StateCount initialTimestamp := df.StateTimestamp - err := df.TransitionToTerminated() + err := df.TransitionToTerminated("test-reason") if tc.expectErr { if err == nil {