-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsingleflight.go
More file actions
169 lines (144 loc) · 3.6 KB
/
singleflight.go
File metadata and controls
169 lines (144 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package resolver
import (
"context"
"encoding/base64"
"time"
"github.com/miekg/dns"
"golang.org/x/sync/singleflight"
"darvaza.org/core"
"darvaza.org/resolver/pkg/client"
"darvaza.org/resolver/pkg/errors"
"darvaza.org/resolver/pkg/exdns"
)
// Compile-time checks that *SingleFlight implements both
// the [Lookuper] and the [Exchanger] interfaces.
var _ Lookuper = (*SingleFlight)(nil)
var _ Exchanger = (*SingleFlight)(nil)
// SingleFlightHasher computes the key under which a request is grouped
// by [SingleFlight]; requests producing the same key share one exchange.
// A non-nil error aborts the exchange and is returned to the caller.
type SingleFlightHasher func(ctx context.Context, req *dns.Msg) (key string, err error)
// SingleFlight is an [Exchanger]/[Lookuper] that holds/caches
// identical queries before passing them over to another [Exchanger].
//
// Requests are grouped by the key produced by h; concurrent requests
// with the same key share a single call to e.
type SingleFlight struct {
	e   Exchanger          // next exchanger, receives the grouped queries
	g   singleflight.Group // tracks in-flight calls by key
	exp time.Duration      // delay before a key is forgotten; <= 0 means immediately
	h   SingleFlightHasher // produces the grouping key for a request
}
// Lookup implements the [Lookuper] interface. It turns qName/qType into an
// INET-class request and passes it through [SingleFlight.Exchange], so
// identical lookups are coalesced. A nil ctx is rejected with
// errors.ErrBadRequest().
func (sf *SingleFlight) Lookup(ctx context.Context, qName string, qType uint16) (*dns.Msg, error) {
	if ctx == nil {
		return nil, errors.ErrBadRequest()
	}

	return sf.Exchange(ctx, exdns.NewRequestFromParts(qName, dns.ClassINET, qType))
}
// Exchange implements the [Exchanger] interface, coalescing identical
// queries into a single upstream exchange.
//
// The request is handled according to how many questions it has:
//   - No questions: answered locally with an empty reply.
//   - One question: forwarded as-is, after being given a random ID if it
//     had none. Note that this assigns req.Id in place.
//   - More than one question: only the first question is forwarded. This
//     is done on a copy that carries a fresh ID, and the response is then
//     combined with the original request by exdns.RestoreReturn.
//
// A nil ctx or req is rejected with errors.ErrBadRequest().
func (sf *SingleFlight) Exchange(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
	if ctx == nil || req == nil {
		return nil, errors.ErrBadRequest()
	}

	n := len(req.Question)
	if n == 0 {
		// nothing to answer
		reply := new(dns.Msg)
		reply.SetReply(req)
		return reply, nil
	}

	var orig *dns.Msg
	if n == 1 {
		if req.Id == 0 {
			// make sure it comes with an ID
			req.Id = dns.Id()
		}
	} else {
		// shrink to the first question, working on a copy
		orig, req = req, req.Copy()
		req.Id = dns.Id()
		req.Question = []dns.Question{req.Question[0]}
	}

	resp, err := sf.doExchange(ctx, req)
	return exdns.RestoreReturn(orig, resp, err)
}
// doExchange forwards a single-question request through the singleflight
// group, keyed by sf.h, so that concurrent identical requests share one
// call to the next [Exchanger].
//
// The key may ignore the request ID (the default hasher does). In that
// case the upstream response carries the ID of whichever request started
// the call. Handing that same *dns.Msg to every waiting caller would give
// them the wrong ID. It would also let them race when they later modify
// the message. So, whenever the result was shared, each caller receives
// its own copy with the ID of its own request.
func (sf *SingleFlight) doExchange(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
	key, err := sf.h(ctx, req)
	if err != nil {
		return nil, err
	}

	// NOTE(review): the upstream exchange runs with the ctx of the caller
	// that started it; if that ctx is cancelled, every caller sharing the
	// key receives the resulting error.
	v, err, shared := sf.g.Do(key, func() (any, error) {
		resp, err := sf.e.Exchange(ctx, req)
		sf.deferredExpiration(key)
		return resp, err
	})

	resp, ok := v.(*dns.Msg)
	switch {
	case ok:
		if shared && resp != nil {
			// never hand the same message to more than one caller
			resp = resp.Copy()
			resp.Id = req.Id
		}
		return resp, err
	case err == nil:
		// this can't happen
		q := msgQuestion(req)
		return nil, errors.ErrInternalError(q.Name, "singleflight")
	default:
		// failed
		return nil, err
	}
}
// deferredExpiration makes the singleflight group forget key. It does so
// once sf.exp has elapsed when sf.exp is positive, and right away otherwise.
//
// NOTE(review): golang.org/x/sync/singleflight already drops the key when
// the call returns. A delayed Forget therefore removes whatever later call
// holds the same key at that moment — confirm this is the intended behavior.
func (sf *SingleFlight) deferredExpiration(key string) {
	if sf.exp <= 0 {
		// immediate
		sf.g.Forget(key)
		return
	}

	// deferred expiration, run on its own goroutine by the runtime timer
	time.AfterFunc(sf.exp, func() {
		sf.g.Forget(key)
	})
}
// NewSingleFlight creates an [Exchanger] wrapper holding/caching identical
// requests. Keys are produced by hasher, or by [DefaultSingleFlightHasher]
// (base64 of the packed request, ignoring the ID) when hasher is nil.
//
// exp controls how long after an upstream exchange returns its key is
// forgotten:
//   - zero is replaced with client.DefaultSingleFlightExpiration;
//   - a negative value means immediately.
//
// It returns core.ErrInvalid if next is nil.
func NewSingleFlight(next Exchanger, exp time.Duration,
	hasher SingleFlightHasher) (*SingleFlight, error) {
	//
	if next == nil {
		return nil, core.ErrInvalid
	}

	// negative exp is valid and selects immediate expiration,
	// as handled by deferredExpiration.
	if exp == 0 {
		exp = client.DefaultSingleFlightExpiration
	}

	if hasher == nil {
		hasher = DefaultSingleFlightHasher
	}

	sf := &SingleFlight{
		e:   next,
		exp: exp,
		h:   hasher,
	}
	return sf, nil
}
// DefaultSingleFlightHasher returns the base64 encoded
// representation of the packed request, ignoring the ID.
//
// The ID is cleared on a shallow copy of the message instead of on req,
// so the caller's req.Id is never written, not even temporarily.
// A nil request yields core.ErrInvalid. A request that can't be packed
// yields errors.ErrBadRequest().
func DefaultSingleFlightHasher(_ context.Context, req *dns.Msg) (string, error) {
	if req == nil {
		return "", core.ErrInvalid
	}

	// shallow copy: identical to req except for the header ID
	m := *req
	m.Id = 0

	b, err := m.Pack()
	if err != nil {
		return "", errors.ErrBadRequest()
	}

	return base64.RawStdEncoding.EncodeToString(b), nil
}