@@ -25,6 +25,7 @@ type DNS struct {
2525 sync.Mutex
2626 disableFallback bool
2727 disableFallbackIfMatch bool
28+ enableParallelQuery bool
2829 ipOption * dns.IPOption
2930 hosts * StaticHosts
3031 clients []* Client
@@ -157,6 +158,7 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
157158 matcherInfos : matcherInfos ,
158159 disableFallback : config .DisableFallback ,
159160 disableFallbackIfMatch : config .DisableFallbackIfMatch ,
161+ enableParallelQuery : config .EnableParallelQuery ,
160162 checkSystem : checkSystem ,
161163 }, nil
162164}
@@ -235,45 +237,11 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
235237 }
236238
237239 // Name servers lookup
238- var errs []error
239- for _ , client := range s .sortClients (domain ) {
240- if ! option .FakeEnable && strings .EqualFold (client .Name (), "FakeDNS" ) {
241- errors .LogDebug (s .ctx , "skip DNS resolution for domain " , domain , " at server " , client .Name ())
242- continue
243- }
244-
245- ips , ttl , err := client .QueryIP (s .ctx , domain , option )
246-
247- if len (ips ) > 0 {
248- if ttl == 0 {
249- ttl = 1
250- }
251- return ips , ttl , nil
252- }
253-
254- errors .LogInfoInner (s .ctx , err , "failed to lookup ip for domain " , domain , " at server " , client .Name ())
255- if err == nil {
256- err = dns .ErrEmptyResponse
257- }
258- errs = append (errs , err )
259-
260- if client .IsFinalQuery () {
261- break
262- }
263- }
264-
265- if len (errs ) > 0 {
266- allErrs := errors .Combine (errs ... )
267- err0 := errs [0 ]
268- if errors .AllEqual (err0 , allErrs ) {
269- if go_errors .Is (err0 , dns .ErrEmptyResponse ) {
270- return nil , 0 , dns .ErrEmptyResponse
271- }
272- return nil , 0 , errors .New ("returning nil for domain " , domain ).Base (err0 )
273- }
274- return nil , 0 , errors .New ("returning nil for domain " , domain ).Base (allErrs )
240+ if s .enableParallelQuery {
241+ return s .parallelQuery (domain , option )
242+ } else {
243+ return s .serialQuery (domain , option )
275244 }
276- return nil , 0 , dns .ErrEmptyResponse
277245}
278246
279247func (s * DNS ) sortClients (domain string ) []* Client {
@@ -300,6 +268,9 @@ func (s *DNS) sortClients(domain string) []*Client {
300268 clients = append (clients , client )
301269 clientNames = append (clientNames , client .Name ())
302270 hasMatch = true
271+ if client .finalQuery {
272+ return clients
273+ }
303274 }
304275
305276 if ! (s .disableFallback || s .disableFallbackIfMatch && hasMatch ) {
@@ -311,6 +282,9 @@ func (s *DNS) sortClients(domain string) []*Client {
311282 clientUsed [idx ] = true
312283 clients = append (clients , client )
313284 clientNames = append (clientNames , client .Name ())
285+ if client .finalQuery {
286+ return clients
287+ }
314288 }
315289 }
316290
@@ -322,14 +296,214 @@ func (s *DNS) sortClients(domain string) []*Client {
322296 }
323297
324298 if len (clients ) == 0 {
325- clients = append (clients , s .clients [0 ])
326- clientNames = append (clientNames , s .clients [0 ].Name ())
327- errors .LogDebug (s .ctx , "domain " , domain , " will use the first DNS: " , clientNames )
299+ if len (s .clients ) > 0 {
300+ clients = append (clients , s .clients [0 ])
301+ clientNames = append (clientNames , s .clients [0 ].Name ())
302+ errors .LogWarning (s .ctx , "domain " , domain , " will use the first DNS: " , clientNames )
303+ } else {
304+ errors .LogError (s .ctx , "no DNS clients available for domain " , domain , " and no default clients configured" )
305+ }
328306 }
329307
330308 return clients
331309}
332310
311+ func mergeQueryErrors (domain string , errs []error ) error {
312+ if len (errs ) == 0 {
313+ return dns .ErrEmptyResponse
314+ }
315+
316+ var noRNF error
317+ for _ , err := range errs {
318+ if go_errors .Is (err , errRecordNotFound ) {
319+ continue // server no response, ignore
320+ } else if noRNF == nil {
321+ noRNF = err
322+ } else if ! go_errors .Is (err , noRNF ) {
323+ return errors .New ("returning nil for domain " , domain ).Base (errors .Combine (errs ... ))
324+ }
325+ }
326+ if go_errors .Is (noRNF , dns .ErrEmptyResponse ) {
327+ return dns .ErrEmptyResponse
328+ }
329+ if noRNF == nil {
330+ noRNF = errRecordNotFound
331+ }
332+ return errors .New ("returning nil for domain " , domain ).Base (noRNF )
333+ }
334+
335+ func (s * DNS ) serialQuery (domain string , option dns.IPOption ) ([]net.IP , uint32 , error ) {
336+ var errs []error
337+ for _ , client := range s .sortClients (domain ) {
338+ if ! option .FakeEnable && strings .EqualFold (client .Name (), "FakeDNS" ) {
339+ errors .LogDebug (s .ctx , "skip DNS resolution for domain " , domain , " at server " , client .Name ())
340+ continue
341+ }
342+
343+ ips , ttl , err := client .QueryIP (s .ctx , domain , option )
344+
345+ if len (ips ) > 0 {
346+ return ips , ttl , nil
347+ }
348+
349+ errors .LogInfoInner (s .ctx , err , "failed to lookup ip for domain " , domain , " at server " , client .Name (), " in serial query mode" )
350+ if err == nil {
351+ err = dns .ErrEmptyResponse
352+ }
353+ errs = append (errs , err )
354+ }
355+ return nil , 0 , mergeQueryErrors (domain , errs )
356+ }
357+
358+ func (s * DNS ) parallelQuery (domain string , option dns.IPOption ) ([]net.IP , uint32 , error ) {
359+ var errs []error
360+ clients := s .sortClients (domain )
361+
362+ resultsChan := asyncQueryAll (domain , option , clients , s .ctx )
363+
364+ groups , groupOf := makeGroups ( /*s.ctx,*/ clients )
365+ results := make ([]* queryResult , len (clients ))
366+ pending := make ([]int , len (groups ))
367+ for gi , g := range groups {
368+ pending [gi ] = g .end - g .start + 1
369+ }
370+
371+ nextGroup := 0
372+ for range clients {
373+ result := <- resultsChan
374+ results [result .index ] = & result
375+
376+ gi := groupOf [result .index ]
377+ pending [gi ]--
378+
379+ for nextGroup < len (groups ) {
380+ g := groups [nextGroup ]
381+
382+ // group race, minimum rtt -> return
383+ for j := g .start ; j <= g .end ; j ++ {
384+ r := results [j ]
385+ if r != nil && r .err == nil && len (r .ips ) > 0 {
386+ return r .ips , r .ttl , nil
387+ }
388+ }
389+
390+ // current group is incomplete and no one success -> continue pending
391+ if pending [nextGroup ] > 0 {
392+ break
393+ }
394+
395+ // all failed -> log and continue next group
396+ for j := g .start ; j <= g .end ; j ++ {
397+ r := results [j ]
398+ e := r .err
399+ if e == nil {
400+ e = dns .ErrEmptyResponse
401+ }
402+ errors .LogInfoInner (s .ctx , e , "failed to lookup ip for domain " , domain , " at server " , clients [j ].Name (), " in parallel query mode" )
403+ errs = append (errs , e )
404+ }
405+ nextGroup ++
406+ }
407+ }
408+
409+ return nil , 0 , mergeQueryErrors (domain , errs )
410+ }
411+
412+ type queryResult struct {
413+ ips []net.IP
414+ ttl uint32
415+ err error
416+ index int
417+ }
418+
419+ func asyncQueryAll (domain string , option dns.IPOption , clients []* Client , ctx context.Context ) chan queryResult {
420+ if len (clients ) == 0 {
421+ ch := make (chan queryResult )
422+ close (ch )
423+ return ch
424+ }
425+
426+ ch := make (chan queryResult , len (clients ))
427+ for i , client := range clients {
428+ if ! option .FakeEnable && strings .EqualFold (client .Name (), "FakeDNS" ) {
429+ errors .LogDebug (ctx , "skip DNS resolution for domain " , domain , " at server " , client .Name ())
430+ ch <- queryResult {err : dns .ErrEmptyResponse , index : i }
431+ continue
432+ }
433+
434+ go func (i int , c * Client ) {
435+ qctx := ctx
436+ if ! c .server .IsDisableCache () {
437+ nctx , cancel := context .WithTimeout (context .WithoutCancel (ctx ), c .timeoutMs * 2 )
438+ qctx = nctx
439+ defer cancel ()
440+ }
441+ ips , ttl , err := c .QueryIP (qctx , domain , option )
442+ ch <- queryResult {ips : ips , ttl : ttl , err : err , index : i }
443+ }(i , client )
444+ }
445+ return ch
446+ }
447+
448+ type group struct { start , end int }
449+
450+ // merge only adjacent and rule-equivalent Client into a single group
451+ func makeGroups ( /*ctx context.Context,*/ clients []* Client ) ([]group , []int ) {
452+ n := len (clients )
453+ if n == 0 {
454+ return nil , nil
455+ }
456+ groups := make ([]group , 0 , n )
457+ groupOf := make ([]int , n )
458+
459+ s , e := 0 , 0
460+ for i := 1 ; i < n ; i ++ {
461+ if clients [i - 1 ].policyID == clients [i ].policyID {
462+ e = i
463+ } else {
464+ for k := s ; k <= e ; k ++ {
465+ groupOf [k ] = len (groups )
466+ }
467+ groups = append (groups , group {start : s , end : e })
468+ s , e = i , i
469+ }
470+ }
471+ for k := s ; k <= e ; k ++ {
472+ groupOf [k ] = len (groups )
473+ }
474+ groups = append (groups , group {start : s , end : e })
475+
476+ // var b strings.Builder
477+ // b.WriteString("dns grouping: total clients=")
478+ // b.WriteString(strconv.Itoa(n))
479+ // b.WriteString(", groups=")
480+ // b.WriteString(strconv.Itoa(len(groups)))
481+
482+ // for gi, g := range groups {
483+ // b.WriteString("\n [")
484+ // b.WriteString(strconv.Itoa(g.start))
485+ // b.WriteString("..")
486+ // b.WriteString(strconv.Itoa(g.end))
487+ // b.WriteString("] gid=")
488+ // b.WriteString(strconv.Itoa(gi))
489+ // b.WriteString(" pid=")
490+ // b.WriteString(strconv.FormatUint(uint64(clients[g.start].policyID), 10))
491+ // b.WriteString(" members: ")
492+
493+ // for i := g.start; i <= g.end; i++ {
494+ // if i > g.start {
495+ // b.WriteString(", ")
496+ // }
497+ // b.WriteString(strconv.Itoa(i))
498+ // b.WriteByte(':')
499+ // b.WriteString(clients[i].Name())
500+ // }
501+ // }
502+ // errors.LogDebug(ctx, b.String())
503+
504+ return groups , groupOf
505+ }
506+
333507func init () {
334508 common .Must (common .RegisterConfig ((* Config )(nil ), func (ctx context.Context , config interface {}) (interface {}, error ) {
335509 return New (ctx , config .(* Config ))
0 commit comments