Skip to content

Commit 30747d4

Browse files
committed
add http、http2、socks5、relay access speed limiter
1 parent 45340b2 commit 30747d4

File tree

9 files changed

+419
-2
lines changed

9 files changed

+419
-2
lines changed

cmd/gost/cfg.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ func parseAuthenticator(s string) (gost.Authenticator, error) {
144144
return au, nil
145145
}
146146

147+
func parseLimiter(s string) (gost.Limiter, error) {
148+
if s == "" {
149+
return nil, nil
150+
}
151+
f, err := os.Open(s)
152+
if err != nil {
153+
return nil, err
154+
}
155+
defer f.Close()
156+
157+
l, _ := gost.NewLocalLimiter("", "")
158+
err = l.Reload(f)
159+
if err != nil {
160+
return nil, err
161+
}
162+
go gost.PeriodReload(l, s)
163+
164+
return l, nil
165+
}
166+
147167
func parseIP(s string, port string) (ips []string) {
148168
if s == "" {
149169
return

cmd/gost/route.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,19 @@ func (r *route) GenRouters() ([]router, error) {
375375
node.User = users[0]
376376
}
377377
}
378+
379+
//init rate limiter
380+
limiterHandler, err := parseLimiter(node.Get("secrets"))
381+
if err != nil {
382+
return nil, err
383+
}
384+
if limiterHandler == nil && strings.TrimSpace(node.Get("limiter")) != "" && node.User != nil {
385+
limiterHandler, err = gost.NewLocalLimiter(node.User.Username(), strings.TrimSpace(node.Get("limiter")))
386+
if err != nil {
387+
return nil, err
388+
}
389+
}
390+
378391
certFile, keyFile := node.Get("cert"), node.Get("key")
379392
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
380393
if err != nil && certFile != "" && keyFile != "" {
@@ -650,6 +663,7 @@ func (r *route) GenRouters() ([]router, error) {
650663
gost.IPsHandlerOption(ips),
651664
gost.TCPModeHandlerOption(node.GetBool("tcp")),
652665
gost.IPRoutesHandlerOption(tunRoutes...),
666+
gost.LimiterHandlerOption(limiterHandler),
653667
)
654668

655669
rt := router{

handler.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type HandlerOptions struct {
4242
IPs []string
4343
TCPMode bool
4444
IPRoutes []IPRoute
45+
Limiter Limiter
4546
}
4647

4748
// HandlerOption allows a common way to set handler options.
@@ -85,6 +86,13 @@ func AuthenticatorHandlerOption(au Authenticator) HandlerOption {
8586
}
8687
}
8788

89+
// LimiterHandlerOption sets the Rate limiter option of HandlerOptions
90+
func LimiterHandlerOption(l Limiter) HandlerOption {
91+
return func(opts *HandlerOptions) {
92+
opts.Limiter = l
93+
}
94+
}
95+
8896
// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions.
8997
func TLSConfigHandlerOption(config *tls.Config) HandlerOption {
9098
return func(opts *HandlerOptions) {

http.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,23 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
206206
if !h.authenticate(conn, req, resp) {
207207
return
208208
}
209+
user, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
210+
if h.options.Limiter != nil {
211+
done, ok := h.options.Limiter.CheckRate(user, true)
212+
if !ok {
213+
resp.StatusCode = http.StatusTooManyRequests
209214

215+
if Debug {
216+
dump, _ := httputil.DumpResponse(resp, false)
217+
log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
218+
}
219+
220+
resp.Write(conn)
221+
return
222+
} else {
223+
defer done()
224+
}
225+
}
210226
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
211227
resp.StatusCode = http.StatusBadRequest
212228

http2.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,18 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
391391
if !h.authenticate(w, r, resp) {
392392
return
393393
}
394-
394+
user, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
395+
if h.options.Limiter != nil {
396+
done, ok := h.options.Limiter.CheckRate(user, true)
397+
if !ok {
398+
log.Logf("[http2] %s - %s rate limiter %s, user is %s",
399+
r.RemoteAddr, laddr, host, user)
400+
w.WriteHeader(http.StatusTooManyRequests)
401+
return
402+
} else {
403+
defer done()
404+
}
405+
}
395406
// delete the proxy related headers.
396407
r.Header.Del("Proxy-Authorization")
397408
r.Header.Del("Proxy-Connection")

limiter.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
package gost
2+
3+
import (
4+
"bufio"
5+
"errors"
6+
"io"
7+
"strconv"
8+
"strings"
9+
"sync"
10+
"sync/atomic"
11+
"time"
12+
)
13+
14+
type Limiter interface {
15+
CheckRate(key string, checkConcurrent bool) (func(), bool)
16+
}
17+
18+
func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) {
19+
limiter := LocalLimiter{
20+
buckets: map[string]*limiterBucket{},
21+
concurrent: map[string]chan bool{},
22+
stopped: make(chan struct{}),
23+
}
24+
if cfg == "" || user == "" {
25+
return &limiter, nil
26+
}
27+
if err := limiter.AddRule(user, cfg); err != nil {
28+
return nil, err
29+
}
30+
return &limiter, nil
31+
}
32+
33+
// Token Bucket
34+
type limiterBucket struct {
35+
max int64
36+
cur int64
37+
duration int64
38+
batch int64
39+
}
40+
41+
type LocalLimiter struct {
42+
buckets map[string]*limiterBucket
43+
concurrent map[string]chan bool
44+
mux sync.RWMutex
45+
stopped chan struct{}
46+
period time.Duration
47+
}
48+
49+
func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) {
50+
if checkConcurrent {
51+
done, ok := l.checkConcurrent(key)
52+
if !ok {
53+
return nil, false
54+
}
55+
if t := l.getToken(key); !t {
56+
done()
57+
return nil, false
58+
}
59+
return done, true
60+
} else {
61+
if t := l.getToken(key); !t {
62+
return nil, false
63+
}
64+
return nil, true
65+
}
66+
}
67+
68+
func (l *LocalLimiter) AddRule(user string, cfg string) error {
69+
if user == "" {
70+
return nil
71+
}
72+
if cfg == "" {
73+
//reload need check old limit exists
74+
if _, ok := l.buckets[user]; ok {
75+
delete(l.buckets, user)
76+
}
77+
if _, ok := l.concurrent[user]; ok {
78+
delete(l.concurrent, user)
79+
}
80+
return nil
81+
}
82+
args := strings.Split(cfg, ",")
83+
if len(args) < 2 || len(args) > 3 {
84+
return errors.New("parse limiter fail:" + cfg)
85+
}
86+
if len(args) == 2 {
87+
args = append(args, "0")
88+
}
89+
90+
duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64)
91+
count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64)
92+
cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64)
93+
if e1 != nil || e2 != nil || e3 != nil {
94+
return errors.New("parse limiter fail:" + cfg)
95+
}
96+
// 0 means not limit
97+
if duration > 0 && count > 0 {
98+
bu := &limiterBucket{
99+
cur: count * 10,
100+
max: count * 10,
101+
duration: duration * 100,
102+
batch: count,
103+
}
104+
go func() {
105+
for {
106+
time.Sleep(time.Millisecond * time.Duration(bu.duration))
107+
if bu.cur+bu.batch > bu.max {
108+
bu.cur = bu.max
109+
} else {
110+
atomic.AddInt64(&bu.cur, bu.batch)
111+
}
112+
}
113+
}()
114+
l.buckets[user] = bu
115+
} else {
116+
if _, ok := l.buckets[user]; ok {
117+
delete(l.buckets, user)
118+
}
119+
}
120+
// zero means not limit
121+
if cur > 0 {
122+
l.concurrent[user] = make(chan bool, cur)
123+
} else {
124+
if _, ok := l.concurrent[user]; ok {
125+
delete(l.concurrent, user)
126+
}
127+
}
128+
return nil
129+
}
130+
131+
// Reload parses config from r, then live reloads the LocalLimiter.
132+
func (l *LocalLimiter) Reload(r io.Reader) error {
133+
var period time.Duration
134+
kvs := make(map[string]string)
135+
136+
if r == nil || l.Stopped() {
137+
return nil
138+
}
139+
140+
// splitLine splits a line text by white space.
141+
// A line started with '#' will be ignored, otherwise it is valid.
142+
split := func(line string) []string {
143+
if line == "" {
144+
return nil
145+
}
146+
line = strings.Replace(line, "\t", " ", -1)
147+
line = strings.TrimSpace(line)
148+
149+
if strings.IndexByte(line, '#') == 0 {
150+
return nil
151+
}
152+
153+
var ss []string
154+
for _, s := range strings.Split(line, " ") {
155+
if s = strings.TrimSpace(s); s != "" {
156+
ss = append(ss, s)
157+
}
158+
}
159+
return ss
160+
}
161+
162+
scanner := bufio.NewScanner(r)
163+
for scanner.Scan() {
164+
line := scanner.Text()
165+
ss := split(line)
166+
if len(ss) == 0 {
167+
continue
168+
}
169+
170+
switch ss[0] {
171+
case "reload": // reload option
172+
if len(ss) > 1 {
173+
period, _ = time.ParseDuration(ss[1])
174+
}
175+
default:
176+
var k, v string
177+
k = ss[0]
178+
if len(ss) > 2 {
179+
v = ss[2]
180+
}
181+
kvs[k] = v
182+
}
183+
}
184+
185+
if err := scanner.Err(); err != nil {
186+
return err
187+
}
188+
189+
l.mux.Lock()
190+
defer l.mux.Unlock()
191+
192+
l.period = period
193+
for user, args := range kvs {
194+
err := l.AddRule(user, args)
195+
if err != nil {
196+
return err
197+
}
198+
}
199+
200+
return nil
201+
}
202+
203+
// Period returns the reload period.
204+
func (l *LocalLimiter) Period() time.Duration {
205+
if l.Stopped() {
206+
return -1
207+
}
208+
209+
l.mux.RLock()
210+
defer l.mux.RUnlock()
211+
212+
return l.period
213+
}
214+
215+
// Stop stops reloading.
216+
func (l *LocalLimiter) Stop() {
217+
select {
218+
case <-l.stopped:
219+
default:
220+
close(l.stopped)
221+
}
222+
}
223+
224+
// Stopped checks whether the reloader is stopped.
225+
func (l *LocalLimiter) Stopped() bool {
226+
select {
227+
case <-l.stopped:
228+
return true
229+
default:
230+
return false
231+
}
232+
}
233+
234+
func (l *LocalLimiter) getToken(key string) bool {
235+
b, ok := l.buckets[key]
236+
if !ok || b == nil {
237+
return true
238+
}
239+
if b.cur <= 0 {
240+
return false
241+
}
242+
atomic.AddInt64(&b.cur, -10)
243+
return true
244+
}
245+
246+
func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) {
247+
c, ok := l.concurrent[key]
248+
if !ok || c == nil {
249+
return func() {}, true
250+
}
251+
select {
252+
case c <- true:
253+
return func() {
254+
<-c
255+
}, true
256+
default:
257+
return nil, false
258+
}
259+
}

0 commit comments

Comments
 (0)