@@ -14,10 +14,13 @@ import (
1414 "syscall"
1515
1616 "encoding/json"
17- "fmt"
1817 "net"
1918 "net/http"
2019
20+ "strings"
21+
22+ "fmt"
23+
2124 "github.com/miekg/dns"
2225 "github.com/wrouesnel/go.log"
2326)
2831 defaultServer = flag .String ("default" , "https://dns.google.com/resolve" ,
2932 "DNS-over-HTTPS service endpoint" )
3033
34+ prefixServer = flag .String ("primary-dns" , "" ,
35+ "If set all DNS queries are attempted against this DNS server first before trying HTTPS" )
36+
37+ suffixServer = flag .String ("fallback-dns" , "" ,
38+ "If set all failed (i.e. NXDOMAIN and others) DNS queries are attempted against this DNS server after HTTPS fails." ) // nolint: lll
39+
40+ fallthroughStatuses = flag .String ("fallthrough-statuses" , "NXDOMAIN" ,
41+ "Comma-separated list of statuses which should cause server fallthrough" )
42+ neverDefault = flag .String ("no-fallthrough" , "" ,
43+ "Comma-separated list of suffixes which will not be allowed to fallthrough (most useful with prefix DNS" )
44+
3145 //routeList = flag.String("route", "",
3246 // "List of routes where to send queries (subdomain=IP:port)")
3347 //routes map[string]string
@@ -157,7 +171,20 @@ func route(w dns.ResponseWriter, req *dns.Msg) {
157171 // return
158172 // }
159173 //}
160- proxy (* defaultServer , w , req )
174+
175+ fallthroughs := make (map [int ]struct {})
176+ for _ , v := range strings .Split (* fallthroughStatuses , "," ) {
177+ rcode , found := dns .StringToRcode [v ]
178+ if ! found {
179+ log .Fatalln ("Could not find matching Rcode integer for" , v )
180+ }
181+
182+ fallthroughs [rcode ] = struct {}{}
183+ }
184+
185+ noFallthrough := strings .Split (* neverDefault , "," )
186+
187+ proxy (* defaultServer , * prefixServer , * suffixServer , fallthroughs , noFallthrough , w , req )
161188}
162189
163190//func isTransfer(req *dns.Msg) bool {
@@ -183,41 +210,17 @@ func route(w dns.ResponseWriter, req *dns.Msg) {
183210// return false
184211//}
185212
186- func proxy (addr string , w dns.ResponseWriter , req * dns.Msg ) {
187- var err error
188- //transport := "udp"
189- //if _, ok := w.RemoteAddr().(*net.TCPAddr); ok {
190- // transport = "tcp"
191- //}
192- //if isTransfer(req) {
193- // if transport != "tcp" {
194- // dns.HandleFailed(w, req)
195- // return
196- // }
197- // t := new(dns.Transfer)
198- // c, err := t.In(req, addr)
199- // if err != nil {
200- // dns.HandleFailed(w, req)
201- // return
202- // }
203- // if err = t.Out(w, req, c); err != nil {
204- // dns.HandleFailed(w, req)
205- // return
206- // }
207- // return
208- //}
209- //c := &dns.Client{Net: "tcp"}
210- //resp, _, err := c.Exchange(req, addr)
211- //if err != nil {
212- // dns.HandleFailed(w, req)
213- // return
214- //}
213+ func dnsRequestProxy (addr string , transport string , req * dns.Msg ) (* dns.Msg , error ) {
214+ c := & dns.Client {Net : transport }
215+ resp , _ , err := c .Exchange (req , addr )
216+ return resp , err
217+ }
215218
219+ func httpDNSRequestProxy (addr string , _ string , req * dns.Msg ) (* dns.Msg , error ) {
216220 httpreq , err := http .NewRequest (http .MethodGet , addr , nil )
217221 if err != nil {
218222 log .Errorln ("Error setting up request:" , err )
219- dns .HandleFailed (w , req )
220- return
223+ return nil , err
221224 }
222225
223226 qry := httpreq .URL .Query ()
@@ -233,9 +236,7 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
233236
234237 httpresp , err := http .DefaultClient .Do (httpreq )
235238 if err != nil {
236- log .Errorln ("Error sending DNS response:" , err )
237- dns .HandleFailed (w , req )
238- return
239+ return nil , err
239240 }
240241 defer httpresp .Body .Close () // nolint: errcheck
241242
@@ -244,9 +245,7 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
244245 decoder := json .NewDecoder (httpresp .Body )
245246 err = decoder .Decode (& dnsResp )
246247 if err != nil {
247- log .Errorln ("Malformed JSON DNS response:" , err )
248- dns .HandleFailed (w , req )
249- return
248+ return nil , err
250249 }
251250
252251 // Parse the google Questions to DNS RRs
@@ -298,9 +297,76 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
298297 Extra : extras ,
299298 }
300299
301- // Write the response
302- err = w .WriteMsg (& resp )
303- if err != nil {
304- log .Errorln ("Error writing DNS response:" , err )
300+ return & resp , nil
301+ }
302+
303+ func isSuccess (fallthroughStatuses map [int ]struct {}, resp * dns.Msg ) bool {
304+ if resp == nil {
305+ return false
306+ }
307+ _ , found := fallthroughStatuses [resp .Rcode ]
308+ return ! found
309+ }
310+
311+ func continueFallthrough (noFallthrough []string , req * dns.Msg ) bool {
312+ for _ , f := range noFallthrough {
313+ if f == "" {
314+ continue
315+ }
316+ for _ , q := range req .Question {
317+ if strings .HasSuffix (q .Name , f ) {
318+ return false
319+ }
320+ }
305321 }
322+ return true
323+ }
324+
325+ type proxyFunc func () (* dns.Msg , error )
326+
327+ func proxy (addr string , prefixServer string , suffixServer string , fallthroughStatuses map [int ]struct {},
328+ noFallthrough []string , w dns.ResponseWriter , req * dns.Msg ) {
329+
330+ qryCanFallthrough := continueFallthrough (noFallthrough , req )
331+
332+ transport := "udp"
333+ if _ , ok := w .RemoteAddr ().(* net.TCPAddr ); ok {
334+ transport = "tcp"
335+ }
336+
337+ proxyFuncs := []proxyFunc {}
338+
339+ // If prefix server set, try prefix server...
340+ if prefixServer != "" {
341+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return dnsRequestProxy (prefixServer , transport , req ) })
342+
343+ }
344+
345+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return httpDNSRequestProxy (addr , transport , req ) })
346+
347+ // If prefix server set, try prefix server...
348+ if suffixServer != "" {
349+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return dnsRequestProxy (suffixServer , transport , req ) })
350+
351+ }
352+
353+ for _ , proxyFunc := range proxyFuncs {
354+ resp , err := proxyFunc ()
355+ if err == nil && (isSuccess (fallthroughStatuses , resp ) || ! qryCanFallthrough ) {
356+ // Write the response
357+ err = w .WriteMsg (resp )
358+ if err != nil {
359+ log .Errorln ("Error writing DNS response:" , err )
360+ dns .HandleFailed (w , req )
361+ }
362+ return
363+ }
364+
365+ if ! qryCanFallthrough {
366+ dns .HandleFailed (w , req )
367+ return
368+ }
369+ }
370+
371+ dns .HandleFailed (w , req )
306372}
0 commit comments