-
Notifications
You must be signed in to change notification settings - Fork 209
Expand file tree
/
Copy pathextensions.go
More file actions
295 lines (259 loc) · 8.45 KB
/
extensions.go
File metadata and controls
295 lines (259 loc) · 8.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
package pubsub
import (
"errors"
"iter"
"github.com/libp2p/go-libp2p-pubsub/partialmessages"
pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
"github.com/libp2p/go-libp2p/core/peer"
)
// PeerExtensions is the set of optional GossipSub extensions advertised by a
// node (ourselves or a remote peer) in the extensions control message it
// attaches to its first RPC.
type PeerExtensions struct {
	// TestExtension and PartialMessages are true when the corresponding
	// extension is advertised.
	TestExtension, PartialMessages bool
}
// TestExtensionConfig configures the test extension installed by
// WithTestExtension.
type TestExtensionConfig struct {
	// OnReceiveTestExtension is handed to the test extension as its receive
	// callback. Presumably it is invoked with the sending peer when a test
	// extension message arrives (TODO confirm against testExtension.HandleRPC).
	OnReceiveTestExtension func(from peer.ID)
}
// WithTestExtension returns an Option that enables the test extension on a
// GossipSub router and advertises it to peers.
//
// If the PubSub router is not a *GossipSubRouter the option does nothing and
// returns no error.
func WithTestExtension(c TestExtensionConfig) Option {
	return func(ps *PubSub) error {
		gs, isGossipSub := ps.rt.(*GossipSubRouter)
		if !isGossipSub {
			// Deliberately a silent no-op for other router types.
			return nil
		}
		gs.extensions.myExtensions.TestExtension = true
		gs.extensions.testExtension = &testExtension{
			onReceiveTestExtension: c.OnReceiveTestExtension,
			sendRPC:                gs.extensions.sendRPC,
		}
		return nil
	}
}
// hasPeerExtensions reports whether rpc carries an extensions control message,
// that is, whether rpc, rpc.Control and rpc.Control.Extensions are all non-nil.
func hasPeerExtensions(rpc *RPC) bool {
	return rpc != nil && rpc.Control != nil && rpc.Control.Extensions != nil
}
// peerExtensionsFromRPC extracts the extension flags advertised in rpc.
// An RPC without an extensions control message yields the zero PeerExtensions,
// with every extension disabled.
func peerExtensionsFromRPC(rpc *RPC) PeerExtensions {
	if !hasPeerExtensions(rpc) {
		return PeerExtensions{}
	}
	ctrlExt := rpc.Control.Extensions
	return PeerExtensions{
		TestExtension:   ctrlExt.GetTestExtension(),
		PartialMessages: ctrlExt.GetPartialMessages(),
	}
}
// ExtendRPC advertises the extensions enabled in pe by setting the matching
// flags in rpc.Control.Extensions. It allocates the Control and Extensions
// messages when they are missing, and returns the (possibly newly allocated)
// RPC.
//
// If no extension is enabled, rpc is returned untouched (even when nil).
// If an extension is enabled and rpc is nil, a fresh RPC is allocated.
//
// The flags written into the message are fresh values rather than pointers
// into pe. Pointers into pe would make the outgoing wire message alias pe, so
// later changes to either one would leak into the other.
func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC {
	if !pe.TestExtension && !pe.PartialMessages {
		return rpc
	}
	if rpc == nil {
		rpc = &RPC{}
	}
	if rpc.Control == nil {
		rpc.Control = &pubsub_pb.ControlMessage{}
	}
	if rpc.Control.Extensions == nil {
		rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
	}
	if pe.TestExtension {
		testExt := true
		rpc.Control.Extensions.TestExtension = &testExt
	}
	if pe.PartialMessages {
		partialMsgs := true
		rpc.Control.Extensions.PartialMessages = &partialMsgs
	}
	return rpc
}
// partialMessageInterface is the non-generic view of the partial messages
// extension held by extensionsState. Going through an interface keeps the
// PeerState type parameter of partialmessages.PartialMessagesExtension from
// propagating into the pubsub types.
//
// This is intentionally specific to partial messages rather than a generic
// extension interface: partial messages is currently the only real consumer.
// That may change in the future.
type partialMessageInterface interface {
	// HandleRPC processes a partial messages payload received from a peer.
	HandleRPC(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error
	// EmitGossip is not called in this file; presumably it gossips partial
	// message state for topic to peers (TODO confirm with its implementation).
	EmitGossip(topic string, peers []peer.ID)
	// Heartbeat is driven by extensionsState.Heartbeat.
	Heartbeat()
	// RemovePeer is called when a peer that completed the extensions exchange,
	// and advertised partial messages, is removed.
	RemovePeer(peer.ID)
}
// extensionsState tracks the extensions control-message exchange with each
// peer. Once both sides have sent theirs, it dispatches peer lifecycle events,
// RPCs and heartbeats to the extension implementations that both sides
// enabled.
type extensionsState struct {
	// myExtensions are the extensions this node advertises in its first RPC
	// to every peer (see AddPeer).
	myExtensions PeerExtensions
	// peerExtensions holds the extensions each peer advertised in its first
	// RPC. A key is present once that first RPC has been handled.
	peerExtensions map[peer.ID]PeerExtensions // peer's extensions
	// sentExtensions records the peers to which our extensions have been sent.
	sentExtensions map[peer.ID]struct{}
	// reportMisbehavior is invoked when a peer sends a second extensions
	// control message.
	reportMisbehavior func(peer.ID)
	// sendRPC is handed to the test extension in WithTestExtension. The bool
	// is presumably the "urgent" flag, matching the gs.sendRPC call in
	// partialMessageRouter.SendRPC (TODO confirm).
	sendRPC func(p peer.ID, r *RPC, urgent bool)
	// testExtension is non-nil only after WithTestExtension has been applied
	// to a GossipSub router.
	testExtension *testExtension
	// partialMessagesExtension is set by WithPartialMessagesExtension and is
	// nil otherwise.
	partialMessagesExtension partialMessageInterface
}
// newExtensionsState returns an extensionsState that advertises myExtensions,
// reports duplicate extensions messages through reportMisbehavior, and uses
// sendRPC for outgoing extension traffic. The per-peer maps start empty.
// The extension implementations stay nil until their With... options install
// them.
func newExtensionsState(myExtensions PeerExtensions, reportMisbehavior func(peer.ID), sendRPC func(peer.ID, *RPC, bool)) *extensionsState {
	state := &extensionsState{
		myExtensions:      myExtensions,
		reportMisbehavior: reportMisbehavior,
		sendRPC:           sendRPC,
	}
	state.peerExtensions = make(map[peer.ID]PeerExtensions)
	state.sentExtensions = make(map[peer.ID]struct{})
	return state
}
// HandleRPC processes an incoming RPC for the extensions layer.
//
// The first RPC from a peer is where it advertises its extensions; they are
// recorded, and if ours were already sent the peer is added to the enabled
// extensions. A later RPC that again carries an extensions control message is
// a protocol error and is reported as misbehavior. In every case the RPC is
// then forwarded to the negotiated extensions, and their error is returned.
func (es *extensionsState) HandleRPC(rpc *RPC) error {
	from := rpc.from
	if _, seen := es.peerExtensions[from]; seen {
		// Extensions may only be advertised once, in the first RPC.
		if hasPeerExtensions(rpc) {
			es.reportMisbehavior(from)
		}
	} else {
		es.peerExtensions[from] = peerExtensionsFromRPC(rpc)
		if _, sent := es.sentExtensions[from]; sent {
			// Both directions of the exchange are now complete.
			es.extensionsAddPeer(from)
		}
	}
	return es.extensionsHandleRPC(rpc)
}
// AddPeer attaches our extensions to helloPacket, which is the first RPC sent
// to id, and returns the extended packet. It records that our extensions were
// sent. If the peer's extensions have already been received, the peer is added
// to the negotiated extensions.
func (es *extensionsState) AddPeer(id peer.ID, helloPacket *RPC) *RPC {
	extended := es.myExtensions.ExtendRPC(helloPacket)
	es.sentExtensions[id] = struct{}{}
	if _, received := es.peerExtensions[id]; received {
		// Both directions of the exchange are now complete.
		es.extensionsAddPeer(id)
	}
	return extended
}
// RemovePeer forgets everything known about id. If the extensions exchange
// with id had completed (sent and received), the negotiated extensions are
// notified first.
func (es *extensionsState) RemovePeer(id peer.ID) {
	_, received := es.peerExtensions[id]
	_, sent := es.sentExtensions[id]
	if sent && received {
		// extensionsAddPeer ran for this peer, so undo it.
		es.extensionsRemovePeer(id)
	}

	// Once a map becomes empty it is replaced with a fresh one. Presumably this
	// lets the buckets of the old map be reclaimed, since delete does not
	// shrink a Go map.
	delete(es.peerExtensions, id)
	if len(es.peerExtensions) == 0 {
		es.peerExtensions = make(map[peer.ID]PeerExtensions)
	}

	delete(es.sentExtensions, id)
	if len(es.sentExtensions) == 0 {
		es.sentExtensions = make(map[peer.ID]struct{})
	}
}
// extensionsAddPeer notifies the extensions enabled on both sides that id has
// joined. It runs only after our extensions were sent to id and id's
// extensions were received.
//
// Only the test extension has an AddPeer hook; partialMessageInterface
// exposes none.
func (es *extensionsState) extensionsAddPeer(id peer.ID) {
	negotiated := es.myExtensions.TestExtension && es.peerExtensions[id].TestExtension
	if !negotiated {
		return
	}
	es.testExtension.AddPeer(id)
}
// extensionsRemovePeer notifies the extensions enabled on both sides that id
// has left. RemovePeer calls it only for peers that completed the exchange, so
// it always follows a matching extensionsAddPeer.
//
// Only the partial messages extension is notified here.
// NOTE(review): the test extension is not told about removals; confirm that
// it needs no cleanup.
func (es *extensionsState) extensionsRemovePeer(id peer.ID) {
	negotiated := es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages
	if !negotiated {
		return
	}
	es.partialMessagesExtension.RemovePeer(id)
}
// extensionsHandleRPC forwards the extension payloads in rpc to each
// extension that both this node and the sender enabled.
//
// The test extension always receives rpc.TestExtension, which may be nil.
// The partial messages extension is called only when rpc.Partial is set, and
// its error is returned to the caller.
func (es *extensionsState) extensionsHandleRPC(rpc *RPC) error {
	from := rpc.from
	if es.myExtensions.TestExtension && es.peerExtensions[from].TestExtension {
		es.testExtension.HandleRPC(from, rpc.TestExtension)
	}

	partialNegotiated := es.myExtensions.PartialMessages && es.peerExtensions[from].PartialMessages
	if !partialNegotiated || rpc.Partial == nil {
		return nil
	}
	return es.partialMessagesExtension.HandleRPC(from, rpc.Partial)
}
// Heartbeat forwards the router heartbeat to the partial messages extension
// when this node has that extension enabled.
func (es *extensionsState) Heartbeat() {
	if !es.myExtensions.PartialMessages {
		return
	}
	es.partialMessagesExtension.Heartbeat()
}
// WithPartialMessagesExtension returns an Option that enables the partial
// messages extension pm on a GossipSub router.
//
// pm is initialized with a router adapter around the GossipSub router. Only if
// Init succeeds is it installed and advertised to peers.
//
// Errors:
//   - the PubSub router is not a *GossipSubRouter;
//   - pm is nil;
//   - pm.Init fails (its error is returned unchanged).
func WithPartialMessagesExtension[PeerState any](pm *partialmessages.PartialMessagesExtension[PeerState]) Option {
	return func(ps *PubSub) error {
		gs, ok := ps.rt.(*GossipSubRouter)
		if !ok {
			return errors.New("pubsub router is not gossipsub")
		}
		// A nil pm would either panic in Init or be stored as a non-nil
		// interface wrapping a nil pointer. The latter slips past the
		// "== nil" check in PublishPartial, so reject nil explicitly.
		if pm == nil {
			return errors.New("partial messages extension is nil")
		}
		if err := pm.Init(partialMessageRouter{gs}); err != nil {
			return err
		}
		gs.extensions.myExtensions.PartialMessages = true
		gs.extensions.partialMessagesExtension = pm
		return nil
	}
}
// PublishPartial publishes partial messages for groupID on topic through the
// partial messages extension installed on ps.
//
// It is a standalone function rather than a method on PubSub because it needs
// the PeerState type parameter. The work runs on the PubSub event loop (via
// ps.eval). PublishPartial waits for the result, and returns ps.ctx.Err() if
// the context is cancelled before the task is queued or before it reports back.
func PublishPartial[PeerState any](ps *PubSub, topic string, groupID []byte, publishActionsFn partialmessages.PublishActionsFn[PeerState]) error {
	// Buffered so the task never blocks on send, even if we've stopped waiting.
	result := make(chan error, 1)

	task := func() {
		defer close(result)
		gs, isGossipSub := ps.rt.(*GossipSubRouter)
		switch {
		case !isGossipSub:
			result <- errors.New("partial publishing is only supported by the GossipSub router")
		case gs.extensions.partialMessagesExtension == nil:
			result <- errors.New("partial publishing is not enabled")
		default:
			ext, compatible := gs.extensions.partialMessagesExtension.(*partialmessages.PartialMessagesExtension[PeerState])
			if !compatible {
				result <- errors.New("incompatible partial messages extension type")
				return
			}
			result <- ext.PublishPartial(topic, groupID, publishActionsFn)
		}
	}

	select {
	case ps.eval <- task:
	case <-ps.ctx.Done():
		return ps.ctx.Err()
	}

	select {
	case err := <-result:
		return err
	case <-ps.ctx.Done():
		return ps.ctx.Err()
	}
}
// partialMessageRouter adapts a *GossipSubRouter to the partialmessages.Router
// interface. WithPartialMessagesExtension hands it to the partial messages
// extension's Init.
type partialMessageRouter struct {
	gs *GossipSubRouter
}
// PeerRequestsPartial reports whether peer p asked to receive partial messages
// on topic.
//
// It intentionally does not check whether this node supports partial messages
// on topic: we might not be subscribed to topic and so might not know.
// Callers that do not support partial messages on topic should not use it.
func (r partialMessageRouter) PeerRequestsPartial(p peer.ID, topic string) bool {
	gs := r.gs
	return gs.peerRequestsPartial(p, topic)
}
// MeshPeers implements partialmessages.Router.
//
// It yields peers from the mesh for topic, or from the fanout set when there
// is no mesh for topic, if they meet both conditions:
//   - they advertised the partial messages extension, and
//   - partial messages can flow in at least one direction: either we request
//     partial messages on topic and the peer supports sending them, or we
//     support sending them and the peer requests them.
//
// The sequence stops early if yield returns false.
func (r partialMessageRouter) MeshPeers(topic string) iter.Seq[peer.ID] {
	return func(yield func(peer.ID) bool) {
		peerSet, ok := r.gs.mesh[topic]
		if !ok {
			// Possibly a fanout topic.
			peerSet, ok = r.gs.fanout[topic]
			if !ok {
				return
			}
		}
		for p := range peerSet {
			// The extension check must gate BOTH directions. Without it (or
			// with "A && B || C" precedence) peers that never advertised the
			// extension would be yielded via the second disjunct.
			if !r.gs.extensions.peerExtensions[p].PartialMessages {
				continue
			}
			exchangesPartial := (r.gs.iRequestPartial(topic) && r.gs.peerSupportsSendingPartial(p, topic)) ||
				(r.gs.iSupportSendingPartial(topic) && r.gs.peerRequestsPartial(p, topic))
			if !exchangesPartial {
				continue
			}
			if !yield(p) {
				return
			}
		}
	}
}
// SendRPC implements partialmessages.Router. It wraps the partial messages
// payload in an otherwise empty RPC and passes it to the GossipSub router's
// sendRPC for peer p with the given urgency.
func (r partialMessageRouter) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) {
	var out RPC
	out.Partial = rpc
	r.gs.sendRPC(p, &out, urgent)
}
// Compile-time assertion that the value type partialMessageRouter implements
// partialmessages.Router.
var _ partialmessages.Router = partialMessageRouter{}