@@ -2,6 +2,7 @@ package networking
22
33import (
44 "fmt"
5+ "net"
56 "strconv"
67 "strings"
78
@@ -26,14 +27,15 @@ type DNSRedirectComponent struct {
2627 targetIPv4 string
2728 targetIPv6 string
2829 targetPort uint16
30+ interfaces []string
2931 ipt4 * iptables.IPTables
3032 ipt6 * iptables.IPTables
3133}
3234
3335// NewDNSRedirectComponent creates a new component for DNS redirection.
3436// listenAddr is the address the DNS proxy listens on (e.g., "[::]" for dual-stack).
3537// When listenAddr is "[::]" (dual-stack), redirects to localhost (::1 for IPv6, 127.0.0.1 for IPv4).
36- func NewDNSRedirectComponent (listenAddr string , targetPort uint16 ) (* DNSRedirectComponent , error ) {
38+ func NewDNSRedirectComponent (listenAddr string , targetPort uint16 , interfaces [] string ) (* DNSRedirectComponent , error ) {
3739 ipt4 , err := iptables .NewWithProtocol (iptables .ProtocolIPv4 )
3840 if err != nil {
3941 return nil , fmt .Errorf ("failed to create iptables (IPv4): %w" , err )
@@ -66,6 +68,7 @@ func NewDNSRedirectComponent(listenAddr string, targetPort uint16) (*DNSRedirect
6668 targetIPv4 : targetIPv4 ,
6769 targetIPv6 : targetIPv6 ,
6870 targetPort : targetPort ,
71+ interfaces : interfaces ,
6972 ipt4 : ipt4 ,
7073 ipt6 : ipt6 ,
7174 }, nil
@@ -113,7 +116,7 @@ func (c *DNSRedirectComponent) CreateIfNotExists() error {
113116 }
114117
115118 // Get all local addresses
116- addresses , err := getLocalAddresses ()
119+ addresses , err := c . getLocalAddresses ()
117120 if err != nil {
118121 return fmt .Errorf ("failed to get local addresses: %w" , err )
119122 }
@@ -225,47 +228,64 @@ func (c *DNSRedirectComponent) createChainAndRules(ipt *iptables.IPTables, addre
225228 }
226229 }
227230
228- // Get list of interfaces instead of addresses
229- links , err := netlink .LinkList ()
230- if err != nil {
231- return fmt .Errorf ("failed to list links: %w" , err )
232- }
233-
234- // Create a set of unique interface names to avoid duplicates
235- interfaceSet := make (map [string ]bool )
236- for _ , link := range links {
237- ifName := link .Attrs ().Name
238- // Skip loopback interface
239- if ifName == "lo" {
231+ // Add rules for each address
232+ for _ , addr := range addresses {
233+ // Check if address matches current iptables protocol
234+ isIPv4 := len (addr .IP ) == net .IPv4len
235+ if ipt .Proto () == iptables .ProtocolIPv4 && ! isIPv4 {
240236 continue
241237 }
242- interfaceSet [ifName ] = true
243- }
244-
245- // Add rules for each interface
246- // Match on incoming interface + destination port 53, redirect to target port
247- // This works for all IP addresses (IPv4, global IPv6, link-local IPv6)
248- for ifName := range interfaceSet {
249- udpRule := []string {
250- "-i" , ifName ,
251- "-p" , "udp" ,
252- "--dport" , strconv .Itoa (dnsSourcePort ),
253- "-j" , "REDIRECT" ,
254- "--to-ports" , strconv .Itoa (int (c .targetPort )),
255- }
256- if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , udpRule ... ); err != nil {
257- return fmt .Errorf ("failed to add UDP rule for %s: %w" , ifName , err )
238+ if ipt .Proto () == iptables .ProtocolIPv6 && isIPv4 {
239+ continue
258240 }
259241
260- tcpRule := []string {
261- "-i" , ifName ,
262- "-p" , "tcp" ,
263- "--dport" , strconv .Itoa (dnsSourcePort ),
264- "-j" , "REDIRECT" ,
265- "--to-ports" , strconv .Itoa (int (c .targetPort )),
266- }
267- if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , tcpRule ... ); err != nil {
268- return fmt .Errorf ("failed to add TCP rule for %s: %w" , ifName , err )
242+ // IPv4: Use REDIRECT with explicit destination IP
243+ if ipt .Proto () == iptables .ProtocolIPv4 {
244+ udpRule := []string {
245+ "-d" , addr .IP .String (),
246+ "-p" , "udp" ,
247+ "--dport" , strconv .Itoa (dnsSourcePort ),
248+ "-j" , "REDIRECT" ,
249+ "--to-ports" , strconv .Itoa (int (c .targetPort )),
250+ }
251+ if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , udpRule ... ); err != nil {
252+ return fmt .Errorf ("failed to add UDP rule for %s: %w" , addr .IP , err )
253+ }
254+
255+ tcpRule := []string {
256+ "-d" , addr .IP .String (),
257+ "-p" , "tcp" ,
258+ "--dport" , strconv .Itoa (dnsSourcePort ),
259+ "-j" , "REDIRECT" ,
260+ "--to-ports" , strconv .Itoa (int (c .targetPort )),
261+ }
262+ if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , tcpRule ... ); err != nil {
263+ return fmt .Errorf ("failed to add TCP rule for %s: %w" , addr .IP , err )
264+ }
265+ } else {
266+ // IPv6: Use DNAT to :port (preserves destination IP)
267+ // This fixes Source IP selection for Link-Local addresses
268+ udpRule := []string {
269+ "-d" , addr .IP .String () + "/128" ,
270+ "-p" , "udp" ,
271+ "--dport" , strconv .Itoa (dnsSourcePort ),
272+ "-j" , "DNAT" ,
273+ "--to-destination" , fmt .Sprintf (":%d" , c .targetPort ),
274+ }
275+ if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , udpRule ... ); err != nil {
276+ return fmt .Errorf ("failed to add UDP rule for %s: %w" , addr .IP , err )
277+ }
278+
279+ tcpRule := []string {
280+ "-d" , addr .IP .String () + "/128" ,
281+ "-p" , "tcp" ,
282+ "--dport" , strconv .Itoa (dnsSourcePort ),
283+ "-j" , "DNAT" ,
284+ "--to-destination" , fmt .Sprintf (":%d" , c .targetPort ),
285+ }
286+ if err := ipt .AppendUnique ("nat" , dnsRedirectChainName , tcpRule ... ); err != nil {
287+ return fmt .Errorf ("failed to add TCP rule for %s: %w" , addr .IP , err )
288+ }
269289 }
270290 }
271291
@@ -294,15 +314,26 @@ func (c *DNSRedirectComponent) deleteChainAndRules(ipt *iptables.IPTables) {
294314 }
295315}
296316
297- // getLocalAddresses returns all local IP addresses.
298- func getLocalAddresses () ([]netlink.Addr , error ) {
317+ // getLocalAddresses returns all local IP addresses for configured interfaces .
318+ func ( c * DNSRedirectComponent ) getLocalAddresses () ([]netlink.Addr , error ) {
299319 links , err := netlink .LinkList ()
300320 if err != nil {
301321 return nil , fmt .Errorf ("failed to list links: %w" , err )
302322 }
303323
324+ // Create map for fast lookup of configured interfaces
325+ configuredInterfaces := make (map [string ]bool )
326+ for _ , iface := range c .interfaces {
327+ configuredInterfaces [iface ] = true
328+ }
329+
304330 var addresses []netlink.Addr
305331 for _ , link := range links {
332+ // Filter interfaces
333+ if ! configuredInterfaces [link .Attrs ().Name ] {
334+ continue
335+ }
336+
306337 addrs , err := netlink .AddrList (link , netlink .FAMILY_ALL )
307338 if err != nil {
308339 log .Debugf ("Failed to get addresses for %s: %v" , link .Attrs ().Name , err )
0 commit comments