Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ type Plugin struct {
MaxHttpResponseBytes int64
MaxVarBytes int64
log func(LogLevel, string)
httpClient *http.Client
hasWasi bool
guestRuntime guestRuntime
Adapter *observe.AdapterBase
Expand Down
38 changes: 38 additions & 0 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -1508,3 +1509,40 @@ func uintToLEBytes(num uint) []byte {
func uintFromLEBytes(bytes []byte) uint {
return uint(bytes[0]) | uint(bytes[1])<<8 | uint(bytes[2])<<16 | uint(bytes[3])<<24
}

func TestHTTP_customClient(t *testing.T) {
manifest := manifest("http.wasm")
manifest.AllowedHosts = []string{"jsonplaceholder.*.com"}

// Track whether our custom transport was used.
var transportUsed bool
customClient := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
transportUsed = true
return http.DefaultTransport.RoundTrip(req)
}),
}

ctx := context.Background()
plugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
EnableWasi: true,
HTTPClient: customClient,
}, nil)
require.NoError(t, err)

instance, err := plugin.Instance(ctx, wasiPluginConfig())
require.NoError(t, err)
defer instance.Close(context.Background())

exit, _, err := instance.Call("run_test", []byte{})
require.NoError(t, err)
assert.Equal(t, uint32(0), exit)
assert.True(t, transportUsed, "custom HTTP client transport should have been used")
}

// roundTripFunc allows using a function as an http.RoundTripper.
type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
2 changes: 1 addition & 1 deletion host.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf
req.Header.Set(key, value)
}

client := http.DefaultClient
client := plugin.httpClient
resp, err := client.Do(req)
if err != nil {
panic(err)
Expand Down
13 changes: 13 additions & 0 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -40,6 +41,7 @@ type CompiledPlugin struct {
maxHttp int64
maxVar int64
enableHttpResponseHeaders bool
httpClient *http.Client
}

type PluginConfig struct {
Expand All @@ -49,6 +51,10 @@ type PluginConfig struct {
ObserveOptions *observe.Options
EnableHttpResponseHeaders bool

// HTTPClient overrides the default HTTP client used for plugin HTTP requests.
// If nil, http.DefaultClient is used.
HTTPClient *http.Client

// ModuleConfig is only used when a plugins are built using the NewPlugin
// function. In this function, the plugin is both compiled, and an instance
// of the plugin is instantiated, and the ModuleConfig is passed to the
Expand Down Expand Up @@ -136,6 +142,11 @@ func NewCompiledPlugin(
}
}

httpClient := config.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}

p := CompiledPlugin{
manifest: manifest,
runtime: wazero.NewRuntimeWithConfig(ctx, runtimeConfig),
Expand All @@ -145,6 +156,7 @@ func NewCompiledPlugin(
modules: make(map[string]wazero.CompiledModule),
maxHttp: calculateMaxHttp(manifest),
maxVar: calculateMaxVar(manifest),
httpClient: httpClient,
}

if config.EnableWasi {
Expand Down Expand Up @@ -399,6 +411,7 @@ func (p *CompiledPlugin) Instance(ctx context.Context, config PluginInstanceConf
close: closers,
extism: extism,
hasWasi: p.hasWasi,
httpClient: p.httpClient,
mainModule: main,
modules: instancedModules,
Timeout: time.Duration(p.manifest.Timeout) * time.Millisecond,
Expand Down