@@ -3,6 +3,7 @@ package main
33import (
44 "fmt"
55 "net"
6+ "net/netip"
67 "sync"
78 "time"
89
@@ -34,7 +35,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
3435 return
3536 }
3637
37- c , err := net .ListenPacket ("udp" , laddr )
38+ lnAddr , err := net .ResolveUDPAddr ("udp" , laddr )
39+ if err != nil {
40+ logf ("UDP listen address error: %v" , err )
41+ return
42+ }
43+
44+ c , err := net .ListenUDP ("udp" , lnAddr )
3845 if err != nil {
3946 logf ("UDP local listen error: %v" , err )
4047 return
@@ -47,13 +54,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
4754
4855 logf ("UDP tunnel %s <-> %s <-> %s" , laddr , server , target )
4956 for {
50- n , raddr , err := c .ReadFrom (buf [len (tgt ):])
57+ n , raddr , err := c .ReadFromUDPAddrPort (buf [len (tgt ):])
5158 if err != nil {
5259 logf ("UDP local read error: %v" , err )
5360 continue
5461 }
5562
56- pc := nm .Get (raddr . String () )
63+ pc := nm .Get (raddr )
5764 if pc == nil {
5865 pc , err = net .ListenPacket ("udp" , "" )
5966 if err != nil {
@@ -81,7 +88,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
8188 return
8289 }
8390
84- c , err := net .ListenPacket ("udp" , laddr )
91+ lnAddr , err := net .ResolveUDPAddr ("udp" , laddr )
92+ if err != nil {
93+ logf ("UDP listen address error: %v" , err )
94+ return
95+ }
96+
97+ c , err := net .ListenUDP ("udp" , lnAddr )
8598 if err != nil {
8699 logf ("UDP local listen error: %v" , err )
87100 return
@@ -92,13 +105,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
92105 buf := make ([]byte , udpBufSize )
93106
94107 for {
95- n , raddr , err := c .ReadFrom (buf )
108+ n , raddr , err := c .ReadFromUDPAddrPort (buf )
96109 if err != nil {
97110 logf ("UDP local read error: %v" , err )
98111 continue
99112 }
100113
101- pc := nm .Get (raddr . String () )
114+ pc := nm .Get (raddr )
102115 if pc == nil {
103116 pc , err = net .ListenPacket ("udp" , "" )
104117 if err != nil {
@@ -118,22 +131,33 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
118131 }
119132}
120133
134+ type UDPConn interface {
135+ net.PacketConn
136+ ReadFromUDPAddrPort ([]byte ) (int , netip.AddrPort , error )
137+ WriteToUDPAddrPort ([]byte , netip.AddrPort ) (int , error )
138+ }
139+
121140// Listen on addr for encrypted packets and basically do UDP NAT.
122141func udpRemote (addr string , shadow func (net.PacketConn ) net.PacketConn ) {
123- c , err := net .ListenPacket ("udp" , addr )
142+ nAddr , err := net .ResolveUDPAddr ("udp" , addr )
143+ if err != nil {
144+ logf ("UDP server address error: %v" , err )
145+ return
146+ }
147+ cc , err := net .ListenUDP ("udp" , nAddr )
124148 if err != nil {
125149 logf ("UDP remote listen error: %v" , err )
126150 return
127151 }
128- defer c .Close ()
129- c = shadow (c )
152+ defer cc .Close ()
153+ c : = shadow (cc ).( UDPConn )
130154
131155 nm := newNATmap (config .UDPTimeout )
132156 buf := make ([]byte , udpBufSize )
133157
134158 logf ("listening UDP on %s" , addr )
135159 for {
136- n , raddr , err := c .ReadFrom (buf )
160+ n , raddr , err := c .ReadFromUDPAddrPort (buf )
137161 if err != nil {
138162 logf ("UDP remote read error: %v" , err )
139163 continue
@@ -153,7 +177,7 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
153177
154178 payload := buf [len (tgtAddr ):n ]
155179
156- pc := nm .Get (raddr . String () )
180+ pc := nm .Get (raddr )
157181 if pc == nil {
158182 pc , err = net .ListenPacket ("udp" , "" )
159183 if err != nil {
@@ -175,31 +199,31 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
175199// Packet NAT table
176200type natmap struct {
177201 sync.RWMutex
178- m map [string ]net.PacketConn
202+ m map [netip. AddrPort ]net.PacketConn
179203 timeout time.Duration
180204}
181205
182206func newNATmap (timeout time.Duration ) * natmap {
183207 m := & natmap {}
184- m .m = make (map [string ]net.PacketConn )
208+ m .m = make (map [netip. AddrPort ]net.PacketConn )
185209 m .timeout = timeout
186210 return m
187211}
188212
189- func (m * natmap ) Get (key string ) net.PacketConn {
213+ func (m * natmap ) Get (key netip. AddrPort ) net.PacketConn {
190214 m .RLock ()
191215 defer m .RUnlock ()
192216 return m .m [key ]
193217}
194218
195- func (m * natmap ) Set (key string , pc net.PacketConn ) {
219+ func (m * natmap ) Set (key netip. AddrPort , pc net.PacketConn ) {
196220 m .Lock ()
197221 defer m .Unlock ()
198222
199223 m .m [key ] = pc
200224}
201225
202- func (m * natmap ) Del (key string ) net.PacketConn {
226+ func (m * natmap ) Del (key netip. AddrPort ) net.PacketConn {
203227 m .Lock ()
204228 defer m .Unlock ()
205229
@@ -211,19 +235,19 @@ func (m *natmap) Del(key string) net.PacketConn {
211235 return nil
212236}
213237
214- func (m * natmap ) Add (peer net. Addr , dst , src net.PacketConn , role mode ) {
215- m .Set (peer . String () , src )
238+ func (m * natmap ) Add (peer netip. AddrPort , dst UDPConn , src net.PacketConn , role mode ) {
239+ m .Set (peer , src )
216240
217241 go func () {
218242 timedCopy (dst , peer , src , m .timeout , role )
219- if pc := m .Del (peer . String () ); pc != nil {
243+ if pc := m .Del (peer ); pc != nil {
220244 pc .Close ()
221245 }
222246 }()
223247}
224248
225249// copy from src to dst at target with read timeout
226- func timedCopy (dst net. PacketConn , target net. Addr , src net.PacketConn , timeout time.Duration , role mode ) error {
250+ func timedCopy (dst UDPConn , target netip. AddrPort , src net.PacketConn , timeout time.Duration , role mode ) error {
227251 buf := make ([]byte , udpBufSize )
228252
229253 for {
@@ -238,12 +262,12 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout
238262 srcAddr := socks .ParseAddr (raddr .String ())
239263 copy (buf [len (srcAddr ):], buf [:n ])
240264 copy (buf , srcAddr )
241- _ , err = dst .WriteTo (buf [:len (srcAddr )+ n ], target )
265+ _ , err = dst .WriteToUDPAddrPort (buf [:len (srcAddr )+ n ], target )
242266 case relayClient : // client -> user: strip original packet source
243267 srcAddr := socks .SplitAddr (buf [:n ])
244- _ , err = dst .WriteTo (buf [len (srcAddr ):n ], target )
268+ _ , err = dst .WriteToUDPAddrPort (buf [len (srcAddr ):n ], target )
245269 case socksClient : // client -> socks5 program: just set RSV and FRAG = 0
246- _ , err = dst .WriteTo (append ([]byte {0 , 0 , 0 }, buf [:n ]... ), target )
270+ _ , err = dst .WriteToUDPAddrPort (append ([]byte {0 , 0 , 0 }, buf [:n ]... ), target )
247271 }
248272
249273 if err != nil {
0 commit comments