Skip to content

Commit 2a8f607

Browse files
authored
Expose DNS resolver functionality for extensions (#5421)
* refactor: change getDialAddr to return *types.Host instead of string This simplifies the internal API by returning the structured Host type directly rather than converting it to a string. The caller can calls .String() on the returned Host when needed. This change makes it easier for future code to access the IP and port components separately without needing to parse the string representation. * feat: add AddrResolver interface and implement ResolveAddr method in Dialer * fix AddrResolver interface to return port as well
1 parent a1faed5 commit 2a8f607

File tree

4 files changed

+178
-7
lines changed

4 files changed

+178
-7
lines changed

lib/netext/dialer.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,32 @@ func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn,
6161
if err != nil {
6262
return nil, err
6363
}
64-
conn, err := d.Dialer.DialContext(ctx, proto, dialAddr)
64+
conn, err := d.Dialer.DialContext(ctx, proto, dialAddr.String())
6565
if err != nil {
6666
return nil, err
6767
}
6868
conn = &Conn{conn, &d.BytesRead, &d.BytesWritten}
6969
return conn, err
7070
}
7171

72+
// ResolveAddr looks up the IP address for the given host and optionally port.
73+
// The address is expected in the form "host:port" or just "host".
74+
// It returns the resolved IP, and an error if any.
75+
func (d *Dialer) ResolveAddr(addr string) (net.IP, int, error) {
76+
// Check if the address has a port, if not add a dummy port for parsing
77+
if _, _, err := net.SplitHostPort(addr); err != nil {
78+
// Address doesn't have a port, add a dummy one
79+
addr = net.JoinHostPort(addr, "0")
80+
}
81+
82+
remote, err := d.getDialAddr(addr)
83+
if err != nil {
84+
return nil, 0, err
85+
}
86+
87+
return remote.IP, remote.Port, nil
88+
}
89+
7290
// IOSamples returns samples for data send and received since it last call and zeros out.
7391
// It uses the provided time as the sample time and tags and builtinMetrics to build the samples.
7492
func (d *Dialer) IOSamples(
@@ -98,19 +116,19 @@ func (d *Dialer) IOSamples(
98116
})
99117
}
100118

101-
func (d *Dialer) getDialAddr(addr string) (string, error) {
119+
func (d *Dialer) getDialAddr(addr string) (*types.Host, error) {
102120
remote, err := d.findRemote(addr)
103121
if err != nil {
104-
return "", err
122+
return nil, err
105123
}
106124

107125
for _, ipnet := range d.Blacklist {
108126
if ipnet.Contains(remote.IP) {
109-
return "", BlackListedIPError{ip: remote.IP, net: ipnet}
127+
return nil, BlackListedIPError{ip: remote.IP, net: ipnet}
110128
}
111129
}
112130

113-
return remote.String(), nil
131+
return remote, nil
114132
}
115133

116134
func (d *Dialer) findRemote(addr string) (*types.Host, error) {

lib/netext/dialer_test.go

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func TestDialerAddr(t *testing.T) {
6666
require.EqualError(t, err, tc.expErr)
6767
} else {
6868
require.NoError(t, err)
69-
require.Equal(t, tc.expAddress, addr)
69+
require.Equal(t, tc.expAddress, addr.String())
7070
}
7171
})
7272
}
@@ -103,7 +103,7 @@ func TestDialerAddrBlockHostnamesStar(t *testing.T) {
103103
require.EqualError(t, err, tc.expErr)
104104
} else {
105105
require.NoError(t, err)
106-
require.Equal(t, tc.expAddress, addr)
106+
require.Equal(t, tc.expAddress, addr.String())
107107
}
108108
})
109109
}
@@ -135,6 +135,80 @@ func BenchmarkDialerHosts(b *testing.B) {
135135
}
136136
}
137137

138+
func TestDialerResolveAddr(t *testing.T) {
139+
t.Parallel()
140+
dialer := NewDialer(net.Dialer{}, newResolver())
141+
hosts, err := types.NewHosts(
142+
map[string]types.Host{
143+
"example.com": {IP: net.ParseIP("3.4.5.6")},
144+
"example.com:443": {IP: net.ParseIP("3.4.5.6"), Port: 8443},
145+
"example.com:8080": {IP: net.ParseIP("3.4.5.6"), Port: 9090},
146+
"example-deny-host.com": {IP: net.ParseIP("8.9.10.11")},
147+
"example-ipv6.com": {IP: net.ParseIP("2001:db8::68")},
148+
"example-ipv6.com:443": {IP: net.ParseIP("2001:db8::68"), Port: 8443},
149+
"example-ipv6-deny-host.com": {IP: net.ParseIP("::1")},
150+
})
151+
require.NoError(t, err)
152+
dialer.Hosts = hosts
153+
ipNet, err := lib.ParseCIDR("8.9.10.0/24")
154+
require.NoError(t, err)
155+
156+
ipV6Net, err := lib.ParseCIDR("::1/24")
157+
require.NoError(t, err)
158+
159+
dialer.Blacklist = []*lib.IPNet{ipNet, ipV6Net}
160+
161+
testCases := []struct {
162+
name string
163+
address string
164+
expectedIP string
165+
expectedPort int
166+
expectedError string
167+
}{
168+
// IPv4 with port
169+
{"IPv4_with_resolver", "example-resolver.com:80", "1.2.3.4", 80, ""},
170+
{"IPv4_with_hosts_mapping", "example.com:80", "3.4.5.6", 80, ""},
171+
{"IPv4_with_custom_port_in_hosts", "example.com:443", "3.4.5.6", 8443, ""},
172+
{"IPv4_with_different_custom_port", "example.com:8080", "3.4.5.6", 9090, ""},
173+
{"IPv4_direct_IP", "1.2.3.4:80", "1.2.3.4", 80, ""},
174+
{"IPv4_blacklisted_via_resolver", "example-deny-resolver.com:80", "", 0, "IP (8.9.10.11) is in a blacklisted range (8.9.10.0/24)"},
175+
{"IPv4_blacklisted_via_hosts", "example-deny-host.com:80", "", 0, "IP (8.9.10.11) is in a blacklisted range (8.9.10.0/24)"},
176+
{"IPv4_non-existent_host", "no-such-host.com:80", "", 0, "lookup no-such-host.com: no such host"},
177+
178+
// IPv4 without port
179+
{"IPv4_without_port", "example-resolver.com", "1.2.3.4", 0, ""},
180+
{"IPv4_hosts_mapping_without_port", "example.com", "3.4.5.6", 0, ""},
181+
{"IPv4_direct_IP_without_port", "1.2.3.4", "1.2.3.4", 0, ""},
182+
183+
// IPv6 with port
184+
{"IPv6_with_hosts_mapping", "example-ipv6.com:443", "2001:db8::68", 8443, ""},
185+
{"IPv6_direct_IP", "[2001:db8:aaaa:1::100]:443", "2001:db8:aaaa:1::100", 443, ""},
186+
{"IPv6_blacklisted_via_resolver", "example-ipv6-deny-resolver.com:80", "", 0, "IP (::1) is in a blacklisted range (::/24)"},
187+
{"IPv6_blacklisted_via_hosts", "example-ipv6-deny-host.com:80", "", 0, "IP (::1) is in a blacklisted range (::/24)"},
188+
189+
// IPv6 without port
190+
{"IPv6_without_port", "example-ipv6.com", "2001:db8::68", 0, ""},
191+
{"IPv6_direct_IP_without_port", "2001:db8:aaaa:1::100", "2001:db8:aaaa:1::100", 0, ""},
192+
}
193+
194+
for _, tc := range testCases {
195+
t.Run(tc.name, func(t *testing.T) {
196+
t.Parallel()
197+
ip, port, err := dialer.ResolveAddr(tc.address)
198+
199+
if tc.expectedError != "" {
200+
require.EqualError(t, err, tc.expectedError)
201+
require.Nil(t, ip)
202+
require.Equal(t, 0, port)
203+
} else {
204+
require.NoError(t, err)
205+
require.Equal(t, net.ParseIP(tc.expectedIP), ip)
206+
require.Equal(t, tc.expectedPort, port)
207+
}
208+
})
209+
}
210+
}
211+
138212
func newResolver() *mockresolver.MockResolver {
139213
return mockresolver.New(
140214
map[string][]net.IP{

lib/vu_state.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ type TracerProvider interface {
2626
Tracer(name string, options ...trace.TracerOption) trace.Tracer
2727
}
2828

29+
// AddrResolver is an interface for DNS resolution.
30+
type AddrResolver interface {
31+
// ResolveAddr looks up the IP address for the given host and optionally port.
32+
// It uses the same DNS resolution logic as DialContext, respecting the
33+
// DNS, Hosts, Blacklist IP, and Block hostnames options.
34+
// The address can be in the form "host:port" or just "host".
35+
// It returns the resolved IP, the port (0 if not specified), and an error if resolution fails.
36+
ResolveAddr(addr string) (net.IP, int, error)
37+
}
38+
2939
// State provides the volatile state for a VU.
3040
//
3141
// TODO: rename to VUState or, better yet, move to some other Go package outside
@@ -92,6 +102,16 @@ type State struct {
92102
TestStatus *TestStatus
93103
}
94104

105+
// GetAddrResolver returns the AddrResolver implementation or nil if not available.
106+
func (s *State) GetAddrResolver() AddrResolver {
107+
resolver, ok := s.Dialer.(AddrResolver)
108+
if !ok {
109+
return nil
110+
}
111+
112+
return resolver
113+
}
114+
95115
// VUStateTags wraps the current VU's tags and ensures a thread-safe way to
96116
// access and modify them exists. This is necessary because the VU tags and
97117
// metadata can be modified from the JS scripts via the `vu.tags` API in the

lib/vu_state_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package lib
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
// mockDialerWithResolver is a mock that implements both DialContexter and AddrResolver
13+
type mockDialerWithResolver struct{}
14+
15+
func (m *mockDialerWithResolver) DialContext(_ context.Context, _, _ string) (net.Conn, error) {
16+
return nil, errors.ErrUnsupported
17+
}
18+
19+
func (m *mockDialerWithResolver) ResolveAddr(_ string) (net.IP, int, error) {
20+
return nil, 0, errors.ErrUnsupported
21+
}
22+
23+
// mockDialerWithoutResolver is a mock that only implements DialContexter
24+
type mockDialerWithoutResolver struct{}
25+
26+
func (m *mockDialerWithoutResolver) DialContext(_ context.Context, _, _ string) (net.Conn, error) {
27+
return nil, errors.ErrUnsupported
28+
}
29+
30+
func TestGetAddrResolver(t *testing.T) {
31+
t.Parallel()
32+
33+
t.Run("returns_same_instance_when_Dialer_implements_AddrResolver", func(t *testing.T) {
34+
t.Parallel()
35+
36+
mock := &mockDialerWithResolver{}
37+
38+
state := &State{
39+
Dialer: mock,
40+
}
41+
42+
resolver := state.GetAddrResolver()
43+
require.NotNil(t, resolver)
44+
require.Same(t, mock, resolver)
45+
})
46+
47+
t.Run("returns_nil_when_Dialer_does_not_implement_AddrResolver", func(t *testing.T) {
48+
t.Parallel()
49+
50+
mock := &mockDialerWithoutResolver{}
51+
52+
state := &State{
53+
Dialer: mock,
54+
}
55+
56+
resolver := state.GetAddrResolver()
57+
require.Nil(t, resolver)
58+
})
59+
}

0 commit comments

Comments
 (0)