Skip to content

Commit 83bcbc3

Browse files
committed
switch Vec[u8] to bytes::Bytes
1 parent f8b14b8 commit 83bcbc3

File tree

6 files changed

+47
-44
lines changed

6 files changed

+47
-44
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ readme = "README.md"
1111

1212
[dependencies]
1313
ahash = "0.8"
14+
bytes = "1"
1415
etherparse = { version = "0.18", default-features = false, features = ["std"] }
1516
log = { version = "0.4", default-features = false }
1617
rand = { version = "0.9", default-features = false, features = ["thread_rng"] }

src/packet.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub(crate) enum TransportHeader {
5050
pub struct NetworkPacket {
5151
pub(crate) ip: IpHeader,
5252
pub(crate) transport: TransportHeader,
53-
pub(crate) payload: Option<Vec<u8>>,
53+
pub(crate) payload: Option<bytes::Bytes>,
5454
}
5555

5656
impl NetworkPacket {
@@ -68,7 +68,7 @@ impl NetworkPacket {
6868
Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()),
6969
_ => (TransportHeader::Unknown, ip_payload),
7070
};
71-
let payload = if payload.is_empty() { None } else { Some(payload.to_vec()) };
71+
let payload = if payload.is_empty() { None } else { Some(payload.to_vec().into()) };
7272

7373
Ok(NetworkPacket { ip, transport, payload })
7474
}

src/stream/tcb.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::seqnum::SeqNum;
2+
use bytes::{Bytes, BytesMut};
23
use etherparse::TcpHeader;
34
use std::collections::BTreeMap;
45

@@ -52,7 +53,7 @@ pub(crate) struct Tcb {
5253
send_window: u16,
5354
state: TcpState,
5455
inflight_packets: BTreeMap<SeqNum, InflightPacket>,
55-
unordered_packets: BTreeMap<SeqNum, Vec<u8>>,
56+
unordered_packets: BTreeMap<SeqNum, Bytes>,
5657
duplicate_ack_count: usize,
5758
duplicate_ack_count_helper: SeqNum,
5859
}
@@ -97,7 +98,7 @@ impl Tcb {
9798
self.duplicate_ack_count >= MAX_COUNT_FOR_DUP_ACK
9899
}
99100

100-
pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec<u8>) {
101+
pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Bytes) {
101102
if seq < self.ack {
102103
#[rustfmt::skip]
103104
log::warn!("{:?}: Received packet seq {seq} < self ack {}, len = {}", self.state, self.ack, buf.len());
@@ -113,8 +114,8 @@ impl Tcb {
113114
self.unordered_packets.values().map(|p| p.len()).sum()
114115
}
115116

116-
pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option<Vec<u8>> {
117-
let mut data = Vec::new();
117+
pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option<Bytes> {
118+
let mut data = BytesMut::new();
118119
let mut remaining_bytes = max_bytes;
119120

120121
while remaining_bytes > 0 {
@@ -145,7 +146,7 @@ impl Tcb {
145146
}
146147
}
147148

148-
if data.is_empty() { None } else { Some(data) }
149+
if data.is_empty() { None } else { Some(data.freeze()) }
149150
}
150151

151152
pub(super) fn increase_seq(&mut self) {
@@ -229,7 +230,7 @@ impl Tcb {
229230
res
230231
}
231232

232-
pub(super) fn add_inflight_packet(&mut self, buf: Vec<u8>) -> std::io::Result<()> {
233+
pub(super) fn add_inflight_packet(&mut self, buf: Bytes) -> std::io::Result<()> {
233234
if buf.is_empty() {
234235
return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Empty payload"));
235236
}
@@ -258,7 +259,7 @@ impl Tcb {
258259
let mut inflight_packet = self.inflight_packets.remove(&seq).unwrap();
259260
let distance = ack.distance(inflight_packet.seq) as usize;
260261
if distance < inflight_packet.payload.len() {
261-
inflight_packet.payload.drain(0..distance);
262+
inflight_packet.payload = inflight_packet.payload.split_off(distance);
262263
inflight_packet.seq = ack;
263264
self.inflight_packets.insert(ack, inflight_packet);
264265
}
@@ -307,14 +308,14 @@ impl Tcb {
307308
#[derive(Debug, Clone)]
308309
pub struct InflightPacket {
309310
pub seq: SeqNum,
310-
pub payload: Vec<u8>,
311+
pub payload: Bytes,
311312
pub send_time: std::time::Instant,
312313
pub retransmit_count: usize,
313314
pub retransmit_timeout: std::time::Duration, // current retransmission timeout
314315
}
315316

316317
impl InflightPacket {
317-
fn new(seq: SeqNum, payload: Vec<u8>) -> Self {
318+
fn new(seq: SeqNum, payload: Bytes) -> Self {
318319
Self {
319320
seq,
320321
payload,
@@ -337,7 +338,7 @@ mod tests {
337338

338339
#[test]
339340
fn test_in_flight_packet() {
340-
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]);
341+
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50].into());
341342

342343
assert!(p.contains_seq_num((u32::MAX - 1).into()));
343344
assert!(p.contains_seq_num(u32::MAX.into()));
@@ -353,9 +354,9 @@ mod tests {
353354
let mut tcb = Tcb::new(SeqNum(1000), 1500);
354355

355356
// insert 3 consecutive packets
356-
tcb.add_unordered_packet(SeqNum(1000), vec![1; 500]); // seq=1000, len=500
357-
tcb.add_unordered_packet(SeqNum(1500), vec![2; 500]); // seq=1500, len=500
358-
tcb.add_unordered_packet(SeqNum(2000), vec![3; 500]); // seq=2000, len=500
357+
tcb.add_unordered_packet(SeqNum(1000), vec![1; 500].into()); // seq=1000, len=500
358+
tcb.add_unordered_packet(SeqNum(1500), vec![2; 500].into()); // seq=1500, len=500
359+
tcb.add_unordered_packet(SeqNum(2000), vec![3; 500].into()); // seq=2000, len=500
359360

360361
// test 1: extract up to 700 bytes
361362
let data = tcb.consume_unordered_packets(700).unwrap();
@@ -386,9 +387,9 @@ mod tests {
386387
tcb.seq = SeqNum(100); // setting the initial seq
387388

388389
// insert 3 consecutive packets
389-
tcb.add_inflight_packet(vec![1; 500]).unwrap(); // seq=100, len=500
390-
tcb.add_inflight_packet(vec![2; 500]).unwrap(); // seq=600, len=500
391-
tcb.add_inflight_packet(vec![3; 500]).unwrap(); // seq=1100, len=500
390+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap(); // seq=100, len=500
391+
tcb.add_inflight_packet(vec![2; 500].into()).unwrap(); // seq=600, len=500
392+
tcb.add_inflight_packet(vec![3; 500].into()).unwrap(); // seq=1100, len=500
392393

393394
// test 1: confirm partial packets (ack=800)
394395
tcb.update_inflight_packet_queue(SeqNum(800));
@@ -410,9 +411,9 @@ mod tests {
410411
tcb.seq = SeqNum(1000);
411412

412413
// Insert 3 consecutive packets
413-
tcb.add_inflight_packet(vec![1; 500]).unwrap(); // seq=1000, len=500
414-
tcb.add_inflight_packet(vec![2; 500]).unwrap(); // seq=1500, len=500
415-
tcb.add_inflight_packet(vec![3; 500]).unwrap(); // seq=2000, len=500
414+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap(); // seq=1000, len=500
415+
tcb.add_inflight_packet(vec![2; 500].into()).unwrap(); // seq=1500, len=500
416+
tcb.add_inflight_packet(vec![3; 500].into()).unwrap(); // seq=2000, len=500
416417

417418
// Emulate cumulative ACK: ack=2500
418419
tcb.update_inflight_packet_queue(SeqNum(2500));
@@ -423,7 +424,7 @@ mod tests {
423424
fn test_retransmit_with_exponential_backoff() {
424425
let mut tcb = Tcb::new(SeqNum(1000), 1500);
425426

426-
tcb.add_inflight_packet(vec![1; 500]).unwrap();
427+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap();
427428

428429
// Simulate retransmission timeouts
429430
for i in 0..MAX_RETRANSMIT_COUNT {

src/stream/tcp.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ pub struct IpStackTcpStream {
8282
destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>,
8383
timeout: Pin<Box<tokio::time::Sleep>>,
8484
timeout_interval: Duration,
85-
data_tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
86-
data_rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
85+
data_tx: tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
86+
data_rx: tokio::sync::mpsc::UnboundedReceiver<bytes::Bytes>,
8787
read_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
8888
task_handle: Option<tokio::task::JoinHandle<std::io::Result<()>>>,
8989
exit_notifier: Option<tokio::sync::mpsc::Sender<()>>,
90-
temp_read_buffer: Vec<u8>,
90+
temp_read_buffer: bytes::Bytes,
9191
}
9292

9393
impl IpStackTcpStream {
@@ -113,7 +113,7 @@ impl IpStackTcpStream {
113113
}
114114

115115
let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::<NetworkPacket>();
116-
let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
116+
let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
117117
let deadline = tokio::time::Instant::now() + timeout_interval;
118118

119119
let mut stream = IpStackTcpStream {
@@ -133,7 +133,7 @@ impl IpStackTcpStream {
133133
read_notify: std::sync::Arc::new(std::sync::Mutex::new(None)),
134134
task_handle: None,
135135
exit_notifier: None,
136-
temp_read_buffer: Vec::new(),
136+
temp_read_buffer: bytes::Bytes::new(),
137137
};
138138

139139
let sessions = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst).saturating_add(1);
@@ -171,7 +171,7 @@ impl AsyncRead for IpStackTcpStream {
171171
if !self.temp_read_buffer.is_empty() {
172172
let len = std::cmp::min(buf.remaining(), self.temp_read_buffer.len());
173173
buf.put_slice(&self.temp_read_buffer[..len]);
174-
self.temp_read_buffer.drain(..len); // remove the read data from the temp buffer
174+
self.temp_read_buffer = self.temp_read_buffer.split_off(len); // remove the read data from the temp buffer
175175
return Poll::Ready(Ok(()));
176176
}
177177

@@ -212,7 +212,7 @@ impl AsyncRead for IpStackTcpStream {
212212
} else {
213213
// if `buf` is not enough, put the remaining data into the temp buffer
214214
buf.put_slice(&data[..capacity]);
215-
self.temp_read_buffer.extend_from_slice(&data[capacity..]);
215+
self.temp_read_buffer = bytes::Bytes::copy_from_slice(&data[capacity..]);
216216
}
217217
Poll::Ready(Ok(()))
218218
}
@@ -251,8 +251,9 @@ impl AsyncWrite for IpStackTcpStream {
251251

252252
let mut tcb = self.tcb.lock().unwrap();
253253
let sender = &self.up_packet_sender;
254-
let payload_len = write_packet_to_device(sender, nt, &tcb, ACK | PSH, None, Some(buf.to_vec()))?;
255-
tcb.add_inflight_packet(buf[..payload_len].to_vec())?;
254+
let buf: bytes::Bytes = buf.to_vec().into();
255+
let payload_len = write_packet_to_device(sender, nt, &tcb, ACK | PSH, None, Some(buf.clone()))?;
256+
tcb.add_inflight_packet(buf[..payload_len].to_vec().into())?;
256257

257258
let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack());
258259
let l_info = format!("local {{ seq: {seq}, ack: {ack} }}");
@@ -388,7 +389,7 @@ async fn tcp_main_logic_loop(
388389
network_tuple: NetworkTuple,
389390
write_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
390391
read_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
391-
data_tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
392+
data_tx: tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
392393
mut exit_monitor: tokio::sync::mpsc::Receiver<()>,
393394
) -> std::io::Result<()> {
394395
{
@@ -737,7 +738,7 @@ fn extract_data_n_write_upstream(
737738
up_packet_sender: &PacketSender,
738739
tcb: &mut Tcb,
739740
network_tuple: NetworkTuple,
740-
data_tx: &tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
741+
data_tx: &tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
741742
read_notify: &std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
742743
) -> std::io::Result<()> {
743744
let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack());
@@ -765,7 +766,7 @@ pub(crate) fn write_packet_to_device(
765766
tcb: &Tcb,
766767
flags: u8,
767768
seq: Option<SeqNum>,
768-
payload: Option<Vec<u8>>,
769+
payload: Option<bytes::Bytes>,
769770
) -> std::io::Result<usize> {
770771
use std::io::Error;
771772
let seq = seq.unwrap_or(tcb.get_seq()).0;
@@ -788,7 +789,7 @@ pub(crate) fn create_raw_packet(
788789
seq: u32,
789790
ack: u32,
790791
win: u16,
791-
mut payload: Vec<u8>,
792+
mut payload: bytes::Bytes,
792793
) -> std::io::Result<NetworkPacket> {
793794
let mut tcp_header = etherparse::TcpHeader::new(src_addr.port(), dst_addr.port(), seq, win);
794795
tcp_header.acknowledgment_number = ack;

src/stream/udp.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub struct IpStackUdpStream {
1717
stream_sender: PacketSender,
1818
stream_receiver: PacketReceiver,
1919
up_pkt_sender: PacketSender,
20-
first_payload: Option<Vec<u8>>,
20+
first_payload: Option<bytes::Bytes>,
2121
timeout: Pin<Box<Sleep>>,
2222
timeout_interval: Duration,
2323
mtu: u16,
@@ -28,7 +28,7 @@ impl IpStackUdpStream {
2828
pub fn new(
2929
src_addr: SocketAddr,
3030
dst_addr: SocketAddr,
31-
payload: Vec<u8>,
31+
payload: bytes::Bytes,
3232
up_pkt_sender: PacketSender,
3333
mtu: u16,
3434
timeout_interval: Duration,
@@ -54,7 +54,7 @@ impl IpStackUdpStream {
5454
self.stream_sender.clone()
5555
}
5656

57-
fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> std::io::Result<NetworkPacket> {
57+
fn create_rev_packet(&self, ttl: u8, mut payload: bytes::Bytes) -> std::io::Result<NetworkPacket> {
5858
const UHS: usize = 8; // udp header size is 8
5959
match (self.dst_addr.ip(), self.src_addr.ip()) {
6060
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
@@ -143,7 +143,7 @@ impl AsyncRead for IpStackUdpStream {
143143
impl AsyncWrite for IpStackUdpStream {
144144
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8]) -> std::task::Poll<std::io::Result<usize>> {
145145
self.reset_timeout();
146-
let packet = self.create_rev_packet(TTL, buf.to_vec())?;
146+
let packet = self.create_rev_packet(TTL, buf.to_vec().into())?;
147147
let payload_len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0);
148148
self.up_pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?;
149149
std::task::Poll::Ready(Ok(payload_len))

src/stream/unknown.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ use std::net::IpAddr;
88
pub struct IpStackUnknownTransport {
99
src_addr: IpAddr,
1010
dst_addr: IpAddr,
11-
payload: Vec<u8>,
11+
payload: bytes::Bytes,
1212
protocol: IpNumber,
1313
mtu: u16,
1414
packet_sender: PacketSender,
1515
}
1616

1717
impl IpStackUnknownTransport {
18-
pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: Vec<u8>, ip: &IpHeader, mtu: u16, packet_sender: PacketSender) -> Self {
18+
pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: bytes::Bytes, ip: &IpHeader, mtu: u16, sender: PacketSender) -> Self {
1919
let protocol = match ip {
2020
IpHeader::Ipv4(ip) => ip.protocol,
2121
IpHeader::Ipv6(ip) => ip.next_header,
@@ -26,7 +26,7 @@ impl IpStackUnknownTransport {
2626
payload,
2727
protocol,
2828
mtu,
29-
packet_sender,
29+
packet_sender: sender,
3030
}
3131
}
3232
pub fn src_addr(&self) -> IpAddr {
@@ -68,7 +68,7 @@ impl IpStackUnknownTransport {
6868
Ok(NetworkPacket {
6969
ip: IpHeader::Ipv4(ip_h),
7070
transport: TransportHeader::Unknown,
71-
payload: Some(p),
71+
payload: Some(p.into()),
7272
})
7373
}
7474
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
@@ -91,7 +91,7 @@ impl IpStackUnknownTransport {
9191
Ok(NetworkPacket {
9292
ip: IpHeader::Ipv6(ip_h),
9393
transport: TransportHeader::Unknown,
94-
payload: Some(p),
94+
payload: Some(p.into()),
9595
})
9696
}
9797
_ => unreachable!(),

0 commit comments

Comments
 (0)