Skip to content

Commit f14fd1c

Browse files
authored
feat(dns): add parallel query (#5239)
1 parent cd51f57 commit f14fd1c

14 files changed

Lines changed: 422 additions & 121 deletions

app/dns/cache_controller.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ func (c *CacheController) migrate() {
194194
return
195195
}
196196

197-
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")
197+
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items")
198198

199199
batch := make([]migrationEntry, 0, migrationBatchSize)
200200
for domain, recD := range dirtyips {
@@ -214,7 +214,7 @@ func (c *CacheController) migrate() {
214214
c.dirtyips = nil
215215
c.Unlock()
216216

217-
errors.LogDebug(context.Background(), c.name, " cache migration completed.")
217+
errors.LogDebug(context.Background(), c.name, " cache migration completed")
218218
}
219219

220220
func (c *CacheController) flush(batch []migrationEntry) {

app/dns/config.pb.go

Lines changed: 86 additions & 65 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

app/dns/config.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ message NameServer {
3737
bool finalQuery = 12;
3838
repeated xray.app.router.GeoIP unexpected_geoip = 13;
3939
bool actUnprior = 14;
40+
uint32 policyID = 17;
4041
}
4142

4243
enum DomainMatchingType {
@@ -89,4 +90,6 @@ message Config {
8990

9091
bool disableFallback = 10;
9192
bool disableFallbackIfMatch = 11;
93+
94+
bool enableParallelQuery = 14;
9295
}

app/dns/dns.go

Lines changed: 215 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type DNS struct {
2525
sync.Mutex
2626
disableFallback bool
2727
disableFallbackIfMatch bool
28+
enableParallelQuery bool
2829
ipOption *dns.IPOption
2930
hosts *StaticHosts
3031
clients []*Client
@@ -157,6 +158,7 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
157158
matcherInfos: matcherInfos,
158159
disableFallback: config.DisableFallback,
159160
disableFallbackIfMatch: config.DisableFallbackIfMatch,
161+
enableParallelQuery: config.EnableParallelQuery,
160162
checkSystem: checkSystem,
161163
}, nil
162164
}
@@ -235,45 +237,11 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
235237
}
236238

237239
// Name servers lookup
238-
var errs []error
239-
for _, client := range s.sortClients(domain) {
240-
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
241-
errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
242-
continue
243-
}
244-
245-
ips, ttl, err := client.QueryIP(s.ctx, domain, option)
246-
247-
if len(ips) > 0 {
248-
if ttl == 0 {
249-
ttl = 1
250-
}
251-
return ips, ttl, nil
252-
}
253-
254-
errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
255-
if err == nil {
256-
err = dns.ErrEmptyResponse
257-
}
258-
errs = append(errs, err)
259-
260-
if client.IsFinalQuery() {
261-
break
262-
}
263-
}
264-
265-
if len(errs) > 0 {
266-
allErrs := errors.Combine(errs...)
267-
err0 := errs[0]
268-
if errors.AllEqual(err0, allErrs) {
269-
if go_errors.Is(err0, dns.ErrEmptyResponse) {
270-
return nil, 0, dns.ErrEmptyResponse
271-
}
272-
return nil, 0, errors.New("returning nil for domain ", domain).Base(err0)
273-
}
274-
return nil, 0, errors.New("returning nil for domain ", domain).Base(allErrs)
240+
if s.enableParallelQuery {
241+
return s.parallelQuery(domain, option)
242+
} else {
243+
return s.serialQuery(domain, option)
275244
}
276-
return nil, 0, dns.ErrEmptyResponse
277245
}
278246

279247
func (s *DNS) sortClients(domain string) []*Client {
@@ -300,6 +268,9 @@ func (s *DNS) sortClients(domain string) []*Client {
300268
clients = append(clients, client)
301269
clientNames = append(clientNames, client.Name())
302270
hasMatch = true
271+
if client.finalQuery {
272+
return clients
273+
}
303274
}
304275

305276
if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) {
@@ -311,6 +282,9 @@ func (s *DNS) sortClients(domain string) []*Client {
311282
clientUsed[idx] = true
312283
clients = append(clients, client)
313284
clientNames = append(clientNames, client.Name())
285+
if client.finalQuery {
286+
return clients
287+
}
314288
}
315289
}
316290

@@ -322,14 +296,214 @@ func (s *DNS) sortClients(domain string) []*Client {
322296
}
323297

324298
if len(clients) == 0 {
325-
clients = append(clients, s.clients[0])
326-
clientNames = append(clientNames, s.clients[0].Name())
327-
errors.LogDebug(s.ctx, "domain ", domain, " will use the first DNS: ", clientNames)
299+
if len(s.clients) > 0 {
300+
clients = append(clients, s.clients[0])
301+
clientNames = append(clientNames, s.clients[0].Name())
302+
errors.LogWarning(s.ctx, "domain ", domain, " will use the first DNS: ", clientNames)
303+
} else {
304+
errors.LogError(s.ctx, "no DNS clients available for domain ", domain, " and no default clients configured")
305+
}
328306
}
329307

330308
return clients
331309
}
332310

311+
func mergeQueryErrors(domain string, errs []error) error {
312+
if len(errs) == 0 {
313+
return dns.ErrEmptyResponse
314+
}
315+
316+
var noRNF error
317+
for _, err := range errs {
318+
if go_errors.Is(err, errRecordNotFound) {
319+
continue // server no response, ignore
320+
} else if noRNF == nil {
321+
noRNF = err
322+
} else if !go_errors.Is(err, noRNF) {
323+
return errors.New("returning nil for domain ", domain).Base(errors.Combine(errs...))
324+
}
325+
}
326+
if go_errors.Is(noRNF, dns.ErrEmptyResponse) {
327+
return dns.ErrEmptyResponse
328+
}
329+
if noRNF == nil {
330+
noRNF = errRecordNotFound
331+
}
332+
return errors.New("returning nil for domain ", domain).Base(noRNF)
333+
}
334+
335+
func (s *DNS) serialQuery(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
336+
var errs []error
337+
for _, client := range s.sortClients(domain) {
338+
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
339+
errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
340+
continue
341+
}
342+
343+
ips, ttl, err := client.QueryIP(s.ctx, domain, option)
344+
345+
if len(ips) > 0 {
346+
return ips, ttl, nil
347+
}
348+
349+
errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name(), " in serial query mode")
350+
if err == nil {
351+
err = dns.ErrEmptyResponse
352+
}
353+
errs = append(errs, err)
354+
}
355+
return nil, 0, mergeQueryErrors(domain, errs)
356+
}
357+
358+
func (s *DNS) parallelQuery(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
359+
var errs []error
360+
clients := s.sortClients(domain)
361+
362+
resultsChan := asyncQueryAll(domain, option, clients, s.ctx)
363+
364+
groups, groupOf := makeGroups( /*s.ctx,*/ clients)
365+
results := make([]*queryResult, len(clients))
366+
pending := make([]int, len(groups))
367+
for gi, g := range groups {
368+
pending[gi] = g.end - g.start + 1
369+
}
370+
371+
nextGroup := 0
372+
for range clients {
373+
result := <-resultsChan
374+
results[result.index] = &result
375+
376+
gi := groupOf[result.index]
377+
pending[gi]--
378+
379+
for nextGroup < len(groups) {
380+
g := groups[nextGroup]
381+
382+
// group race, minimum rtt -> return
383+
for j := g.start; j <= g.end; j++ {
384+
r := results[j]
385+
if r != nil && r.err == nil && len(r.ips) > 0 {
386+
return r.ips, r.ttl, nil
387+
}
388+
}
389+
390+
// current group is incomplete and no one success -> continue pending
391+
if pending[nextGroup] > 0 {
392+
break
393+
}
394+
395+
// all failed -> log and continue next group
396+
for j := g.start; j <= g.end; j++ {
397+
r := results[j]
398+
e := r.err
399+
if e == nil {
400+
e = dns.ErrEmptyResponse
401+
}
402+
errors.LogInfoInner(s.ctx, e, "failed to lookup ip for domain ", domain, " at server ", clients[j].Name(), " in parallel query mode")
403+
errs = append(errs, e)
404+
}
405+
nextGroup++
406+
}
407+
}
408+
409+
return nil, 0, mergeQueryErrors(domain, errs)
410+
}
411+
412+
type queryResult struct {
413+
ips []net.IP
414+
ttl uint32
415+
err error
416+
index int
417+
}
418+
419+
func asyncQueryAll(domain string, option dns.IPOption, clients []*Client, ctx context.Context) chan queryResult {
420+
if len(clients) == 0 {
421+
ch := make(chan queryResult)
422+
close(ch)
423+
return ch
424+
}
425+
426+
ch := make(chan queryResult, len(clients))
427+
for i, client := range clients {
428+
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
429+
errors.LogDebug(ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
430+
ch <- queryResult{err: dns.ErrEmptyResponse, index: i}
431+
continue
432+
}
433+
434+
go func(i int, c *Client) {
435+
qctx := ctx
436+
if !c.server.IsDisableCache() {
437+
nctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.timeoutMs*2)
438+
qctx = nctx
439+
defer cancel()
440+
}
441+
ips, ttl, err := c.QueryIP(qctx, domain, option)
442+
ch <- queryResult{ips: ips, ttl: ttl, err: err, index: i}
443+
}(i, client)
444+
}
445+
return ch
446+
}
447+
448+
type group struct{ start, end int }
449+
450+
// merge only adjacent and rule-equivalent Client into a single group
451+
func makeGroups( /*ctx context.Context,*/ clients []*Client) ([]group, []int) {
452+
n := len(clients)
453+
if n == 0 {
454+
return nil, nil
455+
}
456+
groups := make([]group, 0, n)
457+
groupOf := make([]int, n)
458+
459+
s, e := 0, 0
460+
for i := 1; i < n; i++ {
461+
if clients[i-1].policyID == clients[i].policyID {
462+
e = i
463+
} else {
464+
for k := s; k <= e; k++ {
465+
groupOf[k] = len(groups)
466+
}
467+
groups = append(groups, group{start: s, end: e})
468+
s, e = i, i
469+
}
470+
}
471+
for k := s; k <= e; k++ {
472+
groupOf[k] = len(groups)
473+
}
474+
groups = append(groups, group{start: s, end: e})
475+
476+
// var b strings.Builder
477+
// b.WriteString("dns grouping: total clients=")
478+
// b.WriteString(strconv.Itoa(n))
479+
// b.WriteString(", groups=")
480+
// b.WriteString(strconv.Itoa(len(groups)))
481+
482+
// for gi, g := range groups {
483+
// b.WriteString("\n [")
484+
// b.WriteString(strconv.Itoa(g.start))
485+
// b.WriteString("..")
486+
// b.WriteString(strconv.Itoa(g.end))
487+
// b.WriteString("] gid=")
488+
// b.WriteString(strconv.Itoa(gi))
489+
// b.WriteString(" pid=")
490+
// b.WriteString(strconv.FormatUint(uint64(clients[g.start].policyID), 10))
491+
// b.WriteString(" members: ")
492+
493+
// for i := g.start; i <= g.end; i++ {
494+
// if i > g.start {
495+
// b.WriteString(", ")
496+
// }
497+
// b.WriteString(strconv.Itoa(i))
498+
// b.WriteByte(':')
499+
// b.WriteString(clients[i].Name())
500+
// }
501+
// }
502+
// errors.LogDebug(ctx, b.String())
503+
504+
return groups, groupOf
505+
}
506+
333507
func init() {
334508
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
335509
return New(ctx, config.(*Config))

app/dns/nameserver.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import (
2020
// Server is the interface a DNS name server implementation provides to Client.
type Server interface {
2121
// Name of the Client.
2222
Name() string
23+
24+
// IsDisableCache reports whether caching is disabled for this server.
// NOTE(review): semantics inferred from the method name and its use in the
// parallel query path; the implementations are not visible here — confirm.
IsDisableCache() bool
25+
2326
// QueryIP sends IP queries to its configured server.
2427
QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error)
2528
}
@@ -38,6 +41,7 @@ type Client struct {
3841
finalQuery bool
3942
ipOption *dns.IPOption
4043
checkSystem bool
44+
policyID uint32
4145
}
4246

4347
// NewServer creates a name server object according to the network destination url.
@@ -199,6 +203,7 @@ func NewClient(
199203
client.finalQuery = ns.FinalQuery
200204
client.ipOption = &ipOption
201205
client.checkSystem = checkSystem
206+
client.policyID = ns.PolicyID
202207
return nil
203208
})
204209
return client, err
@@ -209,10 +214,6 @@ func (c *Client) Name() string {
209214
return c.server.Name()
210215
}
211216

212-
func (c *Client) IsFinalQuery() bool {
213-
return c.finalQuery
214-
}
215-
216217
// QueryIP sends DNS query to the name server with the client's IP.
217218
func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
218219
if c.checkSystem {

app/dns/nameserver_cached.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.
2828
ips, ttl, err := merge(option, rec.A, rec.AAAA)
2929
if !go_errors.Is(err, errRecordNotFound) {
3030
if ttl > 0 {
31-
// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
31+
errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
3232
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
3333
return ips, uint32(ttl), err
3434
}
3535
if cache.serveStale && (cache.serveExpiredTTL == 0 || cache.serveExpiredTTL < ttl) {
36-
// errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
36+
errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
3737
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheOptimiste, Elapsed: 0, Error: err})
3838
go pull(ctx, s, fqdn, option)
3939
return ips, 1, err

0 commit comments

Comments
 (0)