Skip to content

Commit 019b9d9

Browse files
authored
make sure to unset batch_key when connection is removed from batch (#48311)
1 parent 7ea40ff commit 019b9d9

File tree

3 files changed

+181
-97
lines changed

3 files changed

+181
-97
lines changed

src/connmgr/batch.rs

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*
22
* Copyright (C) 2020-2023 Fanout, Inc.
3-
* Copyright (C) 2023-2024 Fastly, Inc.
3+
* Copyright (C) 2023-2026 Fastly, Inc.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -28,13 +28,13 @@ pub struct BatchKey {
2828
nkey: usize,
2929
}
3030

31-
pub struct BatchGroup<'a, 'b> {
32-
addr: &'b [u8],
31+
pub struct BatchGroup<'a> {
32+
addr: &'a [u8],
3333
use_router: bool,
34-
ids: memorypool::ReusableVecHandle<'b, zhttppacket::Id<'a>>,
34+
removed: &'a [(usize, bool)],
3535
}
3636

37-
impl<'a> BatchGroup<'a, '_> {
37+
impl BatchGroup<'_> {
3838
pub fn addr(&self) -> &[u8] {
3939
self.addr
4040
}
@@ -43,9 +43,39 @@ impl<'a> BatchGroup<'a, '_> {
4343
self.use_router
4444
}
4545

46-
pub fn ids(&self) -> &[zhttppacket::Id<'a>] {
46+
/// Returns slice of (ckey, included).
47+
pub fn removed(&self) -> &[(usize, bool)] {
48+
self.removed
49+
}
50+
}
51+
52+
pub struct BatchGroupWithIds<'a, 'b> {
53+
inner: BatchGroup<'a>,
54+
ids: memorypool::ReusableVecHandle<'b, zhttppacket::Id<'b>>,
55+
}
56+
57+
impl<'a, 'b> BatchGroupWithIds<'a, 'b> {
58+
pub fn addr(&self) -> &[u8] {
59+
self.inner.addr()
60+
}
61+
62+
pub fn use_router(&self) -> bool {
63+
self.inner.use_router()
64+
}
65+
66+
/// Returns slice of (ckey, included).
67+
#[cfg(test)]
68+
pub fn removed(&self) -> &[(usize, bool)] {
69+
&self.inner.removed()
70+
}
71+
72+
pub fn ids(&self) -> &[zhttppacket::Id<'b>] {
4773
&self.ids
4874
}
75+
76+
pub fn discard_ids(self) -> BatchGroup<'a> {
77+
self.inner
78+
}
4979
}
5080

5181
struct AddrItem {
@@ -59,7 +89,7 @@ pub struct Batch {
5989
addrs: Vec<AddrItem>,
6090
addr_index: usize,
6191
group_ids: memorypool::ReusableVec,
62-
last_group_ckeys: Vec<usize>,
92+
group_removed: Vec<(usize, bool)>,
6393
}
6494

6595
impl Batch {
@@ -69,7 +99,7 @@ impl Batch {
6999
addrs: Vec::with_capacity(capacity),
70100
addr_index: 0,
71101
group_ids: memorypool::ReusableVec::new::<zhttppacket::Id>(capacity),
72-
last_group_ckeys: Vec::with_capacity(capacity),
102+
group_removed: Vec::with_capacity(capacity),
73103
}
74104
}
75105

@@ -147,13 +177,29 @@ impl Batch {
147177
self.nodes.remove(key.nkey);
148178
}
149179

150-
pub fn take_group<'a, 'b: 'a, F>(&'a mut self, get_id: F) -> Option<BatchGroup<'a, 'b>>
180+
/// Returns a set of IDs for connections in the batch that have the same
181+
/// peer. The caller can then easily send a single packet addressed to
182+
/// all of them. The caller should repeatedly call `take_group` until all
183+
/// the connections are drained from the batch. Returns None when there
184+
/// are no connections in the batch.
185+
///
186+
/// This method works by removing connections from the batch one at a
187+
/// time, and calling the `include` function for each one with its ckey.
188+
/// If `include` returns Some((ID, seq)), then it is included in the
189+
/// returned set of IDs. If it returns None, then the connection is
190+
/// excluded from the set.
191+
///
192+
/// If the batch has connections and `include` returns None for all of
193+
/// them, then this method will return an empty set of IDs.
194+
pub fn take_group<'a: 'b, 'b, F>(&'a mut self, include: F) -> Option<BatchGroupWithIds<'a, 'b>>
151195
where
152196
F: Fn(usize) -> Option<(&'b [u8], u32)>,
153197
{
154198
let addrs = &mut self.addrs;
155199
let mut ids = self.group_ids.get_as_new();
156200

201+
self.group_removed.clear();
202+
157203
while ids.is_empty() {
158204
// Find the next addr with items
159205
while self.addr_index < addrs.len() && addrs[self.addr_index].keys.is_empty() {
@@ -163,14 +209,11 @@ impl Batch {
163209
// If all are empty, we're done
164210
if self.addr_index == addrs.len() {
165211
assert!(self.nodes.is_empty());
166-
return None;
212+
break;
167213
}
168214

169215
let keys = &mut addrs[self.addr_index].keys;
170216

171-
self.last_group_ckeys.clear();
172-
ids.clear();
173-
174217
// Get ids/seqs
175218
while ids.len() < zhttppacket::IDS_MAX {
176219
let nkey = match keys.pop_front(&mut self.nodes) {
@@ -179,27 +222,42 @@ impl Batch {
179222
};
180223

181224
let ckey = self.nodes[nkey].value;
182-
self.nodes.remove(nkey);
183225

184-
if let Some((id, seq)) = get_id(ckey) {
185-
self.last_group_ckeys.push(ckey);
226+
let included = if let Some((id, seq)) = include(ckey) {
186227
ids.push(zhttppacket::Id { id, seq: Some(seq) });
187-
}
228+
229+
true
230+
} else {
231+
false
232+
};
233+
234+
self.nodes.remove(nkey);
235+
self.group_removed.push((ckey, included));
188236
}
189237
}
190238

191-
let ai = &addrs[self.addr_index];
239+
if self.group_removed.is_empty() {
240+
assert!(ids.is_empty());
241+
return None;
242+
}
243+
244+
let (addr, use_router): (&[u8], bool) = if !ids.is_empty() {
245+
let ai = &addrs[self.addr_index];
246+
247+
(&ai.addr, ai.use_router)
248+
} else {
249+
(b"", false)
250+
};
192251

193-
Some(BatchGroup {
194-
addr: &ai.addr,
195-
use_router: ai.use_router,
252+
Some(BatchGroupWithIds {
253+
inner: BatchGroup {
254+
addr,
255+
use_router,
256+
removed: &self.group_removed,
257+
},
196258
ids,
197259
})
198260
}
199-
200-
pub fn last_group_ckeys(&self) -> &[usize] {
201-
&self.last_group_ckeys
202-
}
203261
}
204262

205263
#[cfg(test)]
@@ -213,7 +271,6 @@ mod tests {
213271

214272
assert_eq!(batch.capacity(), 4);
215273
assert_eq!(batch.len(), 0);
216-
assert!(batch.last_group_ckeys().is_empty());
217274

218275
assert!(batch.add(b"addr-a", false, 1).is_ok());
219276
assert!(batch.add(b"addr-a", false, 2).is_ok());
@@ -235,9 +292,9 @@ mod tests {
235292
assert_eq!(group.ids()[1].seq, Some(0));
236293
assert_eq!(group.addr(), b"addr-a");
237294
assert!(!group.use_router());
295+
assert_eq!(group.removed(), &[(1, true), (2, true)]);
238296
drop(group);
239297
assert_eq!(batch.is_empty(), false);
240-
assert_eq!(batch.last_group_ckeys(), &[1, 2]);
241298

242299
let group = batch
243300
.take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0)))
@@ -247,9 +304,9 @@ mod tests {
247304
assert_eq!(group.ids()[0].seq, Some(0));
248305
assert_eq!(group.addr(), b"addr-b");
249306
assert!(!group.use_router());
307+
assert_eq!(group.removed(), &[(3, true)]);
250308
drop(group);
251309
assert_eq!(batch.is_empty(), false);
252-
assert_eq!(batch.last_group_ckeys(), &[3]);
253310

254311
let group = batch
255312
.take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0)))
@@ -259,14 +316,13 @@ mod tests {
259316
assert_eq!(group.ids()[0].seq, Some(0));
260317
assert_eq!(group.addr(), b"addr-b");
261318
assert!(group.use_router());
319+
assert_eq!(group.removed(), &[(4, true)]);
262320
drop(group);
263321
assert_eq!(batch.is_empty(), true);
264-
assert_eq!(batch.last_group_ckeys(), &[4]);
265322

266323
assert!(batch
267324
.take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0)))
268325
.is_none());
269-
assert_eq!(batch.last_group_ckeys(), &[4]);
270326
}
271327

272328
#[test]
@@ -287,6 +343,7 @@ mod tests {
287343
assert_eq!(group.ids()[0].id, b"id-2");
288344
assert_eq!(group.ids()[0].seq, Some(0));
289345
assert_eq!(group.addr(), b"addr-b");
346+
assert_eq!(group.removed(), &[(2, true)]);
290347
drop(group);
291348
assert_eq!(batch.is_empty(), true);
292349

@@ -301,6 +358,7 @@ mod tests {
301358
assert_eq!(group.ids()[0].id, b"id-3");
302359
assert_eq!(group.ids()[0].seq, Some(0));
303360
assert_eq!(group.addr(), b"addr-a");
361+
assert_eq!(group.removed(), &[(3, true)]);
304362
drop(group);
305363
assert_eq!(batch.is_empty(), true);
306364
}
@@ -327,6 +385,7 @@ mod tests {
327385
assert_eq!(group.ids()[0].id, b"id-3");
328386
assert_eq!(group.ids()[0].seq, Some(0));
329387
assert_eq!(group.addr(), b"addr-b");
388+
assert_eq!(group.removed(), &[(1, false), (2, false), (3, true)]);
330389
drop(group);
331390
assert_eq!(batch.is_empty(), true);
332391
}

src/connmgr/client.rs

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*
22
* Copyright (C) 2023 Fanout, Inc.
3-
* Copyright (C) 2023-2025 Fastly, Inc.
3+
* Copyright (C) 2023-2026 Fastly, Inc.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
use crate::connmgr::batch::{Batch, BatchKey};
18+
use crate::connmgr::batch::{Batch, BatchGroupWithIds, BatchKey};
1919
use crate::connmgr::connection::{
2020
client_req_connection, client_stream_connection, make_zhttp_response, ConnectionPool,
2121
StreamSharedData,
@@ -116,6 +116,7 @@ fn async_local_channel<T>(
116116
(s, r)
117117
}
118118

119+
#[derive(Copy, Clone)]
119120
enum BatchType {
120121
KeepAlive,
121122
Cancel,
@@ -145,6 +146,29 @@ impl<T> ChannelPool<T> {
145146
}
146147
}
147148

149+
fn make_batch_response(
150+
from: &str,
151+
btype: BatchType,
152+
group: &BatchGroupWithIds,
153+
) -> Result<(Option<ArrayVec<u8, 64>>, zmq::Message), io::Error> {
154+
assert!(group.ids().len() <= zhttppacket::IDS_MAX);
155+
156+
let zresp = zhttppacket::Response {
157+
from: from.as_bytes(),
158+
ids: group.ids(),
159+
multi: true,
160+
ptype: match btype {
161+
BatchType::KeepAlive => zhttppacket::ResponsePacket::KeepAlive,
162+
BatchType::Cancel => zhttppacket::ResponsePacket::Cancel,
163+
},
164+
ptype_str: "",
165+
};
166+
167+
let mut scratch = [0; BULK_PACKET_SIZE_MAX];
168+
169+
make_zhttp_response(group.addr(), group.use_router(), zresp, &mut scratch)
170+
}
171+
148172
struct ConnectionDone {
149173
ckey: usize,
150174
}
@@ -399,7 +423,8 @@ impl Connections {
399423
let nodes = &mut items.nodes;
400424
let batch = &mut items.batch;
401425

402-
while !batch.is_empty() {
426+
loop {
427+
// Wrap in a block to avoid lifetime extension
403428
let group = {
404429
let group = batch.take_group(|ckey| {
405430
let ci = &nodes[ckey].value;
@@ -418,47 +443,36 @@ impl Connections {
418443

419444
match group {
420445
Some(group) => group,
421-
None => continue,
446+
None => break,
422447
}
423448
};
424449

425450
let count = group.ids().len();
451+
let mut to_send = None;
426452

427-
assert!(count <= zhttppacket::IDS_MAX);
428-
429-
let zresp = zhttppacket::Response {
430-
from: from.as_bytes(),
431-
ids: group.ids(),
432-
multi: true,
433-
ptype: match btype {
434-
BatchType::KeepAlive => zhttppacket::ResponsePacket::KeepAlive,
435-
BatchType::Cancel => zhttppacket::ResponsePacket::Cancel,
436-
},
437-
ptype_str: "",
438-
};
453+
if count > 0 {
454+
match make_batch_response(from, btype, &group) {
455+
Ok(ret) => to_send = Some(ret),
456+
Err(e) => error!("failed to serialize batched packet with {count} ids: {e}"),
457+
}
458+
}
439459

440-
let mut scratch = [0; BULK_PACKET_SIZE_MAX];
460+
let group = group.discard_ids();
441461

442-
let (addr, msg) =
443-
match make_zhttp_response(group.addr(), group.use_router(), zresp, &mut scratch) {
444-
Ok(resp) => resp,
445-
Err(e) => {
446-
error!(
447-
"failed to serialize keep-alive packet with {} ids: {}",
448-
count, e
449-
);
450-
continue;
451-
}
452-
};
462+
// Before we do anything that might fail, let's clear the batch keys
463+
for &(ckey, _) in group.removed() {
464+
let ci = &mut nodes[ckey].value;
465+
ci.batch_key = None;
466+
}
453467

454-
drop(group);
468+
let Some((addr, msg)) = to_send else {
469+
continue;
470+
};
455471

456-
for &ckey in batch.last_group_ckeys() {
472+
for &(ckey, _) in group.removed().iter().filter(|(_, included)| *included) {
457473
let ci = &mut nodes[ckey].value;
458474
let cshared = ci.shared.as_ref().unwrap();
459-
460475
cshared.inc_out_seq();
461-
ci.batch_key = None;
462476
}
463477

464478
return Some((count, addr, msg));

0 commit comments

Comments
 (0)