diff --git a/extism.go b/extism.go index cf0383f..e0c2b3c 100644 --- a/extism.go +++ b/extism.go @@ -125,6 +125,7 @@ type Plugin struct { MaxHttpResponseBytes int64 MaxVarBytes int64 log func(LogLevel, string) + httpClient *http.Client hasWasi bool guestRuntime guestRuntime Adapter *observe.AdapterBase diff --git a/extism_test.go b/extism_test.go index 5d51a9a..848488c 100644 --- a/extism_test.go +++ b/extism_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log" + "net/http" "os" "strings" "sync" @@ -1508,3 +1509,40 @@ func uintToLEBytes(num uint) []byte { func uintFromLEBytes(bytes []byte) uint { return uint(bytes[0]) | uint(bytes[1])<<8 | uint(bytes[2])<<16 | uint(bytes[3])<<24 } + +func TestHTTP_customClient(t *testing.T) { + manifest := manifest("http.wasm") + manifest.AllowedHosts = []string{"jsonplaceholder.*.com"} + + // Track whether our custom transport was used. + var transportUsed bool + customClient := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + transportUsed = true + return http.DefaultTransport.RoundTrip(req) + }), + } + + ctx := context.Background() + plugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{ + EnableWasi: true, + HTTPClient: customClient, + }, nil) + require.NoError(t, err) + + instance, err := plugin.Instance(ctx, wasiPluginConfig()) + require.NoError(t, err) + defer instance.Close(context.Background()) + + exit, _, err := instance.Call("run_test", []byte{}) + require.NoError(t, err) + assert.Equal(t, uint32(0), exit) + assert.True(t, transportUsed, "custom HTTP client transport should have been used") +} + +// roundTripFunc allows using a function as an http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/host.go b/host.go index c22e245..927bae8 100644 --- a/host.go +++ b/host.go @@ -569,7 +569,7 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf req.Header.Set(key, value) } - client := http.DefaultClient + client := plugin.httpClient resp, err := client.Do(req) if err != nil { panic(err) diff --git a/plugin.go b/plugin.go index 605c845..8441ad9 100644 --- a/plugin.go +++ b/plugin.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net/http" "os" "strings" "sync/atomic" @@ -40,6 +41,7 @@ type CompiledPlugin struct { maxHttp int64 maxVar int64 enableHttpResponseHeaders bool + httpClient *http.Client } type PluginConfig struct { @@ -49,6 +51,10 @@ type PluginConfig struct { ObserveOptions *observe.Options EnableHttpResponseHeaders bool + // HTTPClient overrides the default HTTP client used for plugin HTTP requests. + // If nil, http.DefaultClient is used. + HTTPClient *http.Client + // ModuleConfig is only used when a plugins are built using the NewPlugin // function. In this function, the plugin is both compiled, and an instance // of the plugin is instantiated, and the ModuleConfig is passed to the @@ -136,6 +142,11 @@ func NewCompiledPlugin( } } + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + p := CompiledPlugin{ manifest: manifest, runtime: wazero.NewRuntimeWithConfig(ctx, runtimeConfig), @@ -145,6 +156,7 @@ func NewCompiledPlugin( modules: make(map[string]wazero.CompiledModule), maxHttp: calculateMaxHttp(manifest), maxVar: calculateMaxVar(manifest), + httpClient: httpClient, } if config.EnableWasi { @@ -399,6 +411,7 @@ func (p *CompiledPlugin) Instance(ctx context.Context, config PluginInstanceConf close: closers, extism: extism, hasWasi: p.hasWasi, + httpClient: p.httpClient, mainModule: main, modules: instancedModules, Timeout: time.Duration(p.manifest.Timeout) * time.Millisecond,