diff --git a/cancel_token.go b/cancel_token.go new file mode 100644 index 0000000..2ba4720 --- /dev/null +++ b/cancel_token.go @@ -0,0 +1,24 @@ +package fetch + +import "context" + +type CancelToken struct { + ctx context.Context + cancel context.CancelFunc +} + +func NewCancelToken() *CancelToken { + ctx, cancel := context.WithCancel(context.Background()) + return &CancelToken{ + ctx: ctx, + cancel: cancel, + } +} + +func (t *CancelToken) Cancel() { + t.cancel() +} + +func (t *CancelToken) Context() context.Context { + return t.ctx +} diff --git a/cancel_token_test.go b/cancel_token_test.go new file mode 100644 index 0000000..f28a97c --- /dev/null +++ b/cancel_token_test.go @@ -0,0 +1,35 @@ +package fetch_test + +import ( + "fmt" + "testing" + "time" + + "github.com/tinh-tinh/fetch/v2" +) + +func Test_CancelToken(t *testing.T) { + // Create a cancel token + token := fetch.NewCancelToken() + + instance := fetch.Create(&fetch.Config{ + BaseUrl: "https://httpbin.org", + CancelToken: token.Context(), + }) + + // Start the request + go func() { + resp := instance.Get("/delay/5") + if resp.Error != nil { + fmt.Println("HTTP error:", resp.Error) + } + }() + + // Cancel after 2 seconds + time.Sleep(1 * time.Second) + fmt.Println("Canceling...") + token.Cancel() + + // Wait a bit to see result + time.Sleep(3 * time.Second) +} diff --git a/config.go b/config.go index 67f9d31..3c05e86 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package fetch import ( + "context" "io" "net/http" "net/url" @@ -24,6 +25,8 @@ type Config struct { WithCredentials bool // ResponseType is the response type that will be used for the request ResponseType string + // Cancel token + CancelToken context.Context } // GetConfig returns a new *http.Request with the given method, uri and input. @@ -48,10 +51,15 @@ func (f *Fetch) GetConfig(method string, uri string, input io.Reader) (*http.Req } var req *http.Request + var formInput io.Reader = nil if input != nil { - req, err = http.NewRequest(method, fullUrl.String(), input) + formInput = input + } + + if f.Config.CancelToken != nil { + req, err = http.NewRequestWithContext(f.Config.CancelToken, method, fullUrl.String(), formInput) } else { - req, err = http.NewRequest(method, fullUrl.String(), nil) + req, err = http.NewRequest(method, fullUrl.String(), formInput) } if f.Config.ResponseType == "json" { diff --git a/fetch_test.go b/fetch_test.go index 7c4a3db..ba9e78d 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -1,7 +1,6 @@ package fetch_test import ( - "fmt" "testing" "time" @@ -72,7 +71,6 @@ func Test_Timeout(t *testing.T) { Timeout: 10 * time.Millisecond, }) resp := instance.Get("comments") - fmt.Println(resp) require.NotNil(t, resp.Error) } diff --git a/module_test.go b/module_test.go index b998ba4..b900c97 100644 --- a/module_test.go +++ b/module_test.go @@ -3,6 +3,7 @@ package fetch_test import ( "fmt" "io" + "net/http" "net/http/httptest" "testing" @@ -78,10 +79,17 @@ func Test_ModuleFactory(t *testing.T) { appModule := core.NewModule(core.NewModuleOptions{ Imports: []core.Modules{ fetch.RegisterFactory(func(ref core.RefProvider) *fetch.Config { - return &fetch.Config{} + return &fetch.Config{ + BaseUrl: "https://jsonplaceholder.typicode.com", + Headers: http.Header{"x-api-key": []string{"abcd"}}, + } }), }, }) - fetchModule := fetch.Inject(appModule) - require.NotNil(t, fetchModule) + fetchConfig := fetch.Inject(appModule) + require.NotNil(t, fetchConfig) + + req, err := fetchConfig.GetConfig("GET", "", nil) + require.Nil(t, err) + require.Equal(t, "abcd", req.Header.Values("x-api-key")[0]) }