Skip to content

Commit e82e917

Browse files
authored
Merge pull request #244 from imgk/master
replace string with netip.AddrPort
2 parents 43c8294 + c2cf969 commit e82e917

File tree

2 files changed

+110
-23
lines changed

2 files changed

+110
-23
lines changed

shadowaead/packet.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"io"
77
"net"
8+
"net/netip"
89
"sync"
910

1011
"github.com/shadowsocks/go-shadowsocks2/internal"
@@ -73,6 +74,9 @@ type packetConn struct {
7374
// NewPacketConn wraps a net.PacketConn with cipher
7475
func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
7576
const maxPacketSize = 64 * 1024
77+
if cc, ok := c.(*net.UDPConn); ok {
78+
return &udpConn{UDPConn: cc, Cipher: ciph, buf: make([]byte, maxPacketSize)}
79+
}
7680
return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)}
7781
}
7882

@@ -101,3 +105,62 @@ func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
101105
copy(b, bb)
102106
return len(bb), addr, err
103107
}
108+
109+
type udpConn struct {
110+
*net.UDPConn
111+
Cipher
112+
sync.Mutex
113+
buf []byte // write lock
114+
}
115+
116+
// WriteTo encrypts b and write to addr using the embedded UDPConn.
117+
func (c *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
118+
c.Lock()
119+
defer c.Unlock()
120+
buf, err := Pack(c.buf, b, c)
121+
if err != nil {
122+
return 0, err
123+
}
124+
_, err = c.UDPConn.WriteTo(buf, addr)
125+
return len(b), err
126+
}
127+
128+
// ReadFrom reads from the embedded UDPConn and decrypts into b.
129+
func (c *udpConn) ReadFrom(b []byte) (int, net.Addr, error) {
130+
n, addr, err := c.UDPConn.ReadFrom(b)
131+
if err != nil {
132+
return n, addr, err
133+
}
134+
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
135+
if err != nil {
136+
return n, addr, err
137+
}
138+
copy(b, bb)
139+
return len(bb), addr, err
140+
}
141+
142+
// WriteToUDPAddrPort encrypts b and write to addr using the embedded PacketConn.
143+
func (c *udpConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
144+
c.Lock()
145+
defer c.Unlock()
146+
buf, err := Pack(c.buf, b, c)
147+
if err != nil {
148+
return 0, err
149+
}
150+
_, err = c.UDPConn.WriteToUDPAddrPort(buf, addr)
151+
return len(b), err
152+
}
153+
154+
// ReadFromUDPAddrPort reads from the embedded UDPConn and decrypts into b.
155+
func (c *udpConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
156+
n, addr, err := c.UDPConn.ReadFromUDPAddrPort(b)
157+
if err != nil {
158+
return n, addr, err
159+
}
160+
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
161+
if err != nil {
162+
return n, addr, err
163+
}
164+
copy(b, bb)
165+
return len(bb), addr, err
166+
}

udp.go

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"fmt"
55
"net"
6+
"net/netip"
67
"sync"
78
"time"
89

@@ -34,7 +35,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
3435
return
3536
}
3637

37-
c, err := net.ListenPacket("udp", laddr)
38+
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
39+
if err != nil {
40+
logf("UDP listen address error: %v", err)
41+
return
42+
}
43+
44+
c, err := net.ListenUDP("udp", lnAddr)
3845
if err != nil {
3946
logf("UDP local listen error: %v", err)
4047
return
@@ -47,13 +54,13 @@ func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.Pack
4754

4855
logf("UDP tunnel %s <-> %s <-> %s", laddr, server, target)
4956
for {
50-
n, raddr, err := c.ReadFrom(buf[len(tgt):])
57+
n, raddr, err := c.ReadFromUDPAddrPort(buf[len(tgt):])
5158
if err != nil {
5259
logf("UDP local read error: %v", err)
5360
continue
5461
}
5562

56-
pc := nm.Get(raddr.String())
63+
pc := nm.Get(raddr)
5764
if pc == nil {
5865
pc, err = net.ListenPacket("udp", "")
5966
if err != nil {
@@ -81,7 +88,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
8188
return
8289
}
8390

84-
c, err := net.ListenPacket("udp", laddr)
91+
lnAddr, err := net.ResolveUDPAddr("udp", laddr)
92+
if err != nil {
93+
logf("UDP listen address error: %v", err)
94+
return
95+
}
96+
97+
c, err := net.ListenUDP("udp", lnAddr)
8598
if err != nil {
8699
logf("UDP local listen error: %v", err)
87100
return
@@ -92,13 +105,13 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
92105
buf := make([]byte, udpBufSize)
93106

94107
for {
95-
n, raddr, err := c.ReadFrom(buf)
108+
n, raddr, err := c.ReadFromUDPAddrPort(buf)
96109
if err != nil {
97110
logf("UDP local read error: %v", err)
98111
continue
99112
}
100113

101-
pc := nm.Get(raddr.String())
114+
pc := nm.Get(raddr)
102115
if pc == nil {
103116
pc, err = net.ListenPacket("udp", "")
104117
if err != nil {
@@ -118,22 +131,33 @@ func udpSocksLocal(laddr, server string, shadow func(net.PacketConn) net.PacketC
118131
}
119132
}
120133

134+
type UDPConn interface {
135+
net.PacketConn
136+
ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
137+
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
138+
}
139+
121140
// Listen on addr for encrypted packets and basically do UDP NAT.
122141
func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
123-
c, err := net.ListenPacket("udp", addr)
142+
nAddr, err := net.ResolveUDPAddr("udp", addr)
143+
if err != nil {
144+
logf("UDP server address error: %v", err)
145+
return
146+
}
147+
cc, err := net.ListenUDP("udp", nAddr)
124148
if err != nil {
125149
logf("UDP remote listen error: %v", err)
126150
return
127151
}
128-
defer c.Close()
129-
c = shadow(c)
152+
defer cc.Close()
153+
c := shadow(cc).(UDPConn)
130154

131155
nm := newNATmap(config.UDPTimeout)
132156
buf := make([]byte, udpBufSize)
133157

134158
logf("listening UDP on %s", addr)
135159
for {
136-
n, raddr, err := c.ReadFrom(buf)
160+
n, raddr, err := c.ReadFromUDPAddrPort(buf)
137161
if err != nil {
138162
logf("UDP remote read error: %v", err)
139163
continue
@@ -153,7 +177,7 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
153177

154178
payload := buf[len(tgtAddr):n]
155179

156-
pc := nm.Get(raddr.String())
180+
pc := nm.Get(raddr)
157181
if pc == nil {
158182
pc, err = net.ListenPacket("udp", "")
159183
if err != nil {
@@ -175,31 +199,31 @@ func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) {
175199
// Packet NAT table
176200
type natmap struct {
177201
sync.RWMutex
178-
m map[string]net.PacketConn
202+
m map[netip.AddrPort]net.PacketConn
179203
timeout time.Duration
180204
}
181205

182206
func newNATmap(timeout time.Duration) *natmap {
183207
m := &natmap{}
184-
m.m = make(map[string]net.PacketConn)
208+
m.m = make(map[netip.AddrPort]net.PacketConn)
185209
m.timeout = timeout
186210
return m
187211
}
188212

189-
func (m *natmap) Get(key string) net.PacketConn {
213+
func (m *natmap) Get(key netip.AddrPort) net.PacketConn {
190214
m.RLock()
191215
defer m.RUnlock()
192216
return m.m[key]
193217
}
194218

195-
func (m *natmap) Set(key string, pc net.PacketConn) {
219+
func (m *natmap) Set(key netip.AddrPort, pc net.PacketConn) {
196220
m.Lock()
197221
defer m.Unlock()
198222

199223
m.m[key] = pc
200224
}
201225

202-
func (m *natmap) Del(key string) net.PacketConn {
226+
func (m *natmap) Del(key netip.AddrPort) net.PacketConn {
203227
m.Lock()
204228
defer m.Unlock()
205229

@@ -211,19 +235,19 @@ func (m *natmap) Del(key string) net.PacketConn {
211235
return nil
212236
}
213237

214-
func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) {
215-
m.Set(peer.String(), src)
238+
func (m *natmap) Add(peer netip.AddrPort, dst UDPConn, src net.PacketConn, role mode) {
239+
m.Set(peer, src)
216240

217241
go func() {
218242
timedCopy(dst, peer, src, m.timeout, role)
219-
if pc := m.Del(peer.String()); pc != nil {
243+
if pc := m.Del(peer); pc != nil {
220244
pc.Close()
221245
}
222246
}()
223247
}
224248

225249
// copy from src to dst at target with read timeout
226-
func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, role mode) error {
250+
func timedCopy(dst UDPConn, target netip.AddrPort, src net.PacketConn, timeout time.Duration, role mode) error {
227251
buf := make([]byte, udpBufSize)
228252

229253
for {
@@ -238,12 +262,12 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout
238262
srcAddr := socks.ParseAddr(raddr.String())
239263
copy(buf[len(srcAddr):], buf[:n])
240264
copy(buf, srcAddr)
241-
_, err = dst.WriteTo(buf[:len(srcAddr)+n], target)
265+
_, err = dst.WriteToUDPAddrPort(buf[:len(srcAddr)+n], target)
242266
case relayClient: // client -> user: strip original packet source
243267
srcAddr := socks.SplitAddr(buf[:n])
244-
_, err = dst.WriteTo(buf[len(srcAddr):n], target)
268+
_, err = dst.WriteToUDPAddrPort(buf[len(srcAddr):n], target)
245269
case socksClient: // client -> socks5 program: just set RSV and FRAG = 0
246-
_, err = dst.WriteTo(append([]byte{0, 0, 0}, buf[:n]...), target)
270+
_, err = dst.WriteToUDPAddrPort(append([]byte{0, 0, 0}, buf[:n]...), target)
247271
}
248272

249273
if err != nil {

0 commit comments

Comments
 (0)