Skip to content

Commit fb0f307

Browse files
committed
switch Vec[u8] to bytes::Bytes
1 parent f8b14b8 commit fb0f307

File tree

6 files changed

+46
-44
lines changed

6 files changed

+46
-44
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ readme = "README.md"
1111

1212
[dependencies]
1313
ahash = "0.8"
14+
bytes = "1"
1415
etherparse = { version = "0.18", default-features = false, features = ["std"] }
1516
log = { version = "0.4", default-features = false }
1617
rand = { version = "0.9", default-features = false, features = ["thread_rng"] }

src/packet.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub(crate) enum TransportHeader {
5050
pub struct NetworkPacket {
5151
pub(crate) ip: IpHeader,
5252
pub(crate) transport: TransportHeader,
53-
pub(crate) payload: Option<Vec<u8>>,
53+
pub(crate) payload: Option<bytes::Bytes>,
5454
}
5555

5656
impl NetworkPacket {
@@ -68,7 +68,7 @@ impl NetworkPacket {
6868
Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()),
6969
_ => (TransportHeader::Unknown, ip_payload),
7070
};
71-
let payload = if payload.is_empty() { None } else { Some(payload.to_vec()) };
71+
let payload = if payload.is_empty() { None } else { Some(payload.to_vec().into()) };
7272

7373
Ok(NetworkPacket { ip, transport, payload })
7474
}

src/stream/tcb.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub(crate) struct Tcb {
5252
send_window: u16,
5353
state: TcpState,
5454
inflight_packets: BTreeMap<SeqNum, InflightPacket>,
55-
unordered_packets: BTreeMap<SeqNum, Vec<u8>>,
55+
unordered_packets: BTreeMap<SeqNum, bytes::Bytes>,
5656
duplicate_ack_count: usize,
5757
duplicate_ack_count_helper: SeqNum,
5858
}
@@ -97,7 +97,7 @@ impl Tcb {
9797
self.duplicate_ack_count >= MAX_COUNT_FOR_DUP_ACK
9898
}
9999

100-
pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec<u8>) {
100+
pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: bytes::Bytes) {
101101
if seq < self.ack {
102102
#[rustfmt::skip]
103103
log::warn!("{:?}: Received packet seq {seq} < self ack {}, len = {}", self.state, self.ack, buf.len());
@@ -113,8 +113,8 @@ impl Tcb {
113113
self.unordered_packets.values().map(|p| p.len()).sum()
114114
}
115115

116-
pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option<Vec<u8>> {
117-
let mut data = Vec::new();
116+
pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option<bytes::Bytes> {
117+
let mut data = bytes::BytesMut::new();
118118
let mut remaining_bytes = max_bytes;
119119

120120
while remaining_bytes > 0 {
@@ -145,7 +145,7 @@ impl Tcb {
145145
}
146146
}
147147

148-
if data.is_empty() { None } else { Some(data) }
148+
if data.is_empty() { None } else { Some(data.freeze()) }
149149
}
150150

151151
pub(super) fn increase_seq(&mut self) {
@@ -229,7 +229,7 @@ impl Tcb {
229229
res
230230
}
231231

232-
pub(super) fn add_inflight_packet(&mut self, buf: Vec<u8>) -> std::io::Result<()> {
232+
pub(super) fn add_inflight_packet(&mut self, buf: bytes::Bytes) -> std::io::Result<()> {
233233
if buf.is_empty() {
234234
return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Empty payload"));
235235
}
@@ -258,7 +258,7 @@ impl Tcb {
258258
let mut inflight_packet = self.inflight_packets.remove(&seq).unwrap();
259259
let distance = ack.distance(inflight_packet.seq) as usize;
260260
if distance < inflight_packet.payload.len() {
261-
inflight_packet.payload.drain(0..distance);
261+
inflight_packet.payload = inflight_packet.payload.split_off(distance);
262262
inflight_packet.seq = ack;
263263
self.inflight_packets.insert(ack, inflight_packet);
264264
}
@@ -307,14 +307,14 @@ impl Tcb {
307307
#[derive(Debug, Clone)]
308308
pub struct InflightPacket {
309309
pub seq: SeqNum,
310-
pub payload: Vec<u8>,
310+
pub payload: bytes::Bytes,
311311
pub send_time: std::time::Instant,
312312
pub retransmit_count: usize,
313313
pub retransmit_timeout: std::time::Duration, // current retransmission timeout
314314
}
315315

316316
impl InflightPacket {
317-
fn new(seq: SeqNum, payload: Vec<u8>) -> Self {
317+
fn new(seq: SeqNum, payload: bytes::Bytes) -> Self {
318318
Self {
319319
seq,
320320
payload,
@@ -337,7 +337,7 @@ mod tests {
337337

338338
#[test]
339339
fn test_in_flight_packet() {
340-
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]);
340+
let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50].into());
341341

342342
assert!(p.contains_seq_num((u32::MAX - 1).into()));
343343
assert!(p.contains_seq_num(u32::MAX.into()));
@@ -353,9 +353,9 @@ mod tests {
353353
let mut tcb = Tcb::new(SeqNum(1000), 1500);
354354

355355
// insert 3 consecutive packets
356-
tcb.add_unordered_packet(SeqNum(1000), vec![1; 500]); // seq=1000, len=500
357-
tcb.add_unordered_packet(SeqNum(1500), vec![2; 500]); // seq=1500, len=500
358-
tcb.add_unordered_packet(SeqNum(2000), vec![3; 500]); // seq=2000, len=500
356+
tcb.add_unordered_packet(SeqNum(1000), vec![1; 500].into()); // seq=1000, len=500
357+
tcb.add_unordered_packet(SeqNum(1500), vec![2; 500].into()); // seq=1500, len=500
358+
tcb.add_unordered_packet(SeqNum(2000), vec![3; 500].into()); // seq=2000, len=500
359359

360360
// test 1: extract up to 700 bytes
361361
let data = tcb.consume_unordered_packets(700).unwrap();
@@ -386,9 +386,9 @@ mod tests {
386386
tcb.seq = SeqNum(100); // setting the initial seq
387387

388388
// insert 3 consecutive packets
389-
tcb.add_inflight_packet(vec![1; 500]).unwrap(); // seq=100, len=500
390-
tcb.add_inflight_packet(vec![2; 500]).unwrap(); // seq=600, len=500
391-
tcb.add_inflight_packet(vec![3; 500]).unwrap(); // seq=1100, len=500
389+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap(); // seq=100, len=500
390+
tcb.add_inflight_packet(vec![2; 500].into()).unwrap(); // seq=600, len=500
391+
tcb.add_inflight_packet(vec![3; 500].into()).unwrap(); // seq=1100, len=500
392392

393393
// test 1: confirm partial packets (ack=800)
394394
tcb.update_inflight_packet_queue(SeqNum(800));
@@ -410,9 +410,9 @@ mod tests {
410410
tcb.seq = SeqNum(1000);
411411

412412
// Insert 3 consecutive packets
413-
tcb.add_inflight_packet(vec![1; 500]).unwrap(); // seq=1000, len=500
414-
tcb.add_inflight_packet(vec![2; 500]).unwrap(); // seq=1500, len=500
415-
tcb.add_inflight_packet(vec![3; 500]).unwrap(); // seq=2000, len=500
413+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap(); // seq=1000, len=500
414+
tcb.add_inflight_packet(vec![2; 500].into()).unwrap(); // seq=1500, len=500
415+
tcb.add_inflight_packet(vec![3; 500].into()).unwrap(); // seq=2000, len=500
416416

417417
// Emulate cumulative ACK: ack=2500
418418
tcb.update_inflight_packet_queue(SeqNum(2500));
@@ -423,7 +423,7 @@ mod tests {
423423
fn test_retransmit_with_exponential_backoff() {
424424
let mut tcb = Tcb::new(SeqNum(1000), 1500);
425425

426-
tcb.add_inflight_packet(vec![1; 500]).unwrap();
426+
tcb.add_inflight_packet(vec![1; 500].into()).unwrap();
427427

428428
// Simulate retransmission timeouts
429429
for i in 0..MAX_RETRANSMIT_COUNT {

src/stream/tcp.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ pub struct IpStackTcpStream {
8282
destroy_messenger: Option<::tokio::sync::oneshot::Sender<()>>,
8383
timeout: Pin<Box<tokio::time::Sleep>>,
8484
timeout_interval: Duration,
85-
data_tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
86-
data_rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
85+
data_tx: tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
86+
data_rx: tokio::sync::mpsc::UnboundedReceiver<bytes::Bytes>,
8787
read_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
8888
task_handle: Option<tokio::task::JoinHandle<std::io::Result<()>>>,
8989
exit_notifier: Option<tokio::sync::mpsc::Sender<()>>,
90-
temp_read_buffer: Vec<u8>,
90+
temp_read_buffer: bytes::Bytes,
9191
}
9292

9393
impl IpStackTcpStream {
@@ -113,7 +113,7 @@ impl IpStackTcpStream {
113113
}
114114

115115
let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::<NetworkPacket>();
116-
let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
116+
let (data_tx, data_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
117117
let deadline = tokio::time::Instant::now() + timeout_interval;
118118

119119
let mut stream = IpStackTcpStream {
@@ -133,7 +133,7 @@ impl IpStackTcpStream {
133133
read_notify: std::sync::Arc::new(std::sync::Mutex::new(None)),
134134
task_handle: None,
135135
exit_notifier: None,
136-
temp_read_buffer: Vec::new(),
136+
temp_read_buffer: bytes::Bytes::new(),
137137
};
138138

139139
let sessions = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst).saturating_add(1);
@@ -171,7 +171,7 @@ impl AsyncRead for IpStackTcpStream {
171171
if !self.temp_read_buffer.is_empty() {
172172
let len = std::cmp::min(buf.remaining(), self.temp_read_buffer.len());
173173
buf.put_slice(&self.temp_read_buffer[..len]);
174-
self.temp_read_buffer.drain(..len); // remove the read data from the temp buffer
174+
self.temp_read_buffer = self.temp_read_buffer.split_off(len); // remove the read data from the temp buffer
175175
return Poll::Ready(Ok(()));
176176
}
177177

@@ -212,7 +212,7 @@ impl AsyncRead for IpStackTcpStream {
212212
} else {
213213
// if `buf` is not enough, put the remaining data into the temp buffer
214214
buf.put_slice(&data[..capacity]);
215-
self.temp_read_buffer.extend_from_slice(&data[capacity..]);
215+
self.temp_read_buffer = bytes::Bytes::copy_from_slice(&data[capacity..]);
216216
}
217217
Poll::Ready(Ok(()))
218218
}
@@ -251,8 +251,9 @@ impl AsyncWrite for IpStackTcpStream {
251251

252252
let mut tcb = self.tcb.lock().unwrap();
253253
let sender = &self.up_packet_sender;
254-
let payload_len = write_packet_to_device(sender, nt, &tcb, ACK | PSH, None, Some(buf.to_vec()))?;
255-
tcb.add_inflight_packet(buf[..payload_len].to_vec())?;
254+
let buf: bytes::Bytes = buf.to_vec().into();
255+
let payload_len = write_packet_to_device(sender, nt, &tcb, ACK | PSH, None, Some(buf.clone()))?;
256+
tcb.add_inflight_packet(buf[..payload_len].to_vec().into())?;
256257

257258
let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack());
258259
let l_info = format!("local {{ seq: {seq}, ack: {ack} }}");
@@ -388,7 +389,7 @@ async fn tcp_main_logic_loop(
388389
network_tuple: NetworkTuple,
389390
write_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
390391
read_notify: std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
391-
data_tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
392+
data_tx: tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
392393
mut exit_monitor: tokio::sync::mpsc::Receiver<()>,
393394
) -> std::io::Result<()> {
394395
{
@@ -737,7 +738,7 @@ fn extract_data_n_write_upstream(
737738
up_packet_sender: &PacketSender,
738739
tcb: &mut Tcb,
739740
network_tuple: NetworkTuple,
740-
data_tx: &tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
741+
data_tx: &tokio::sync::mpsc::UnboundedSender<bytes::Bytes>,
741742
read_notify: &std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
742743
) -> std::io::Result<()> {
743744
let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack());
@@ -765,7 +766,7 @@ pub(crate) fn write_packet_to_device(
765766
tcb: &Tcb,
766767
flags: u8,
767768
seq: Option<SeqNum>,
768-
payload: Option<Vec<u8>>,
769+
payload: Option<bytes::Bytes>,
769770
) -> std::io::Result<usize> {
770771
use std::io::Error;
771772
let seq = seq.unwrap_or(tcb.get_seq()).0;
@@ -788,7 +789,7 @@ pub(crate) fn create_raw_packet(
788789
seq: u32,
789790
ack: u32,
790791
win: u16,
791-
mut payload: Vec<u8>,
792+
mut payload: bytes::Bytes,
792793
) -> std::io::Result<NetworkPacket> {
793794
let mut tcp_header = etherparse::TcpHeader::new(src_addr.port(), dst_addr.port(), seq, win);
794795
tcp_header.acknowledgment_number = ack;

src/stream/udp.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub struct IpStackUdpStream {
1717
stream_sender: PacketSender,
1818
stream_receiver: PacketReceiver,
1919
up_pkt_sender: PacketSender,
20-
first_payload: Option<Vec<u8>>,
20+
first_payload: Option<bytes::Bytes>,
2121
timeout: Pin<Box<Sleep>>,
2222
timeout_interval: Duration,
2323
mtu: u16,
@@ -28,7 +28,7 @@ impl IpStackUdpStream {
2828
pub fn new(
2929
src_addr: SocketAddr,
3030
dst_addr: SocketAddr,
31-
payload: Vec<u8>,
31+
payload: bytes::Bytes,
3232
up_pkt_sender: PacketSender,
3333
mtu: u16,
3434
timeout_interval: Duration,
@@ -54,7 +54,7 @@ impl IpStackUdpStream {
5454
self.stream_sender.clone()
5555
}
5656

57-
fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> std::io::Result<NetworkPacket> {
57+
fn create_rev_packet(&self, ttl: u8, mut payload: bytes::Bytes) -> std::io::Result<NetworkPacket> {
5858
const UHS: usize = 8; // udp header size is 8
5959
match (self.dst_addr.ip(), self.src_addr.ip()) {
6060
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
@@ -143,7 +143,7 @@ impl AsyncRead for IpStackUdpStream {
143143
impl AsyncWrite for IpStackUdpStream {
144144
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8]) -> std::task::Poll<std::io::Result<usize>> {
145145
self.reset_timeout();
146-
let packet = self.create_rev_packet(TTL, buf.to_vec())?;
146+
let packet = self.create_rev_packet(TTL, buf.to_vec().into())?;
147147
let payload_len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0);
148148
self.up_pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?;
149149
std::task::Poll::Ready(Ok(payload_len))

src/stream/unknown.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ use std::net::IpAddr;
88
pub struct IpStackUnknownTransport {
99
src_addr: IpAddr,
1010
dst_addr: IpAddr,
11-
payload: Vec<u8>,
11+
payload: bytes::Bytes,
1212
protocol: IpNumber,
1313
mtu: u16,
1414
packet_sender: PacketSender,
1515
}
1616

1717
impl IpStackUnknownTransport {
18-
pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: Vec<u8>, ip: &IpHeader, mtu: u16, packet_sender: PacketSender) -> Self {
18+
pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: bytes::Bytes, ip: &IpHeader, mtu: u16, sender: PacketSender) -> Self {
1919
let protocol = match ip {
2020
IpHeader::Ipv4(ip) => ip.protocol,
2121
IpHeader::Ipv6(ip) => ip.next_header,
@@ -26,7 +26,7 @@ impl IpStackUnknownTransport {
2626
payload,
2727
protocol,
2828
mtu,
29-
packet_sender,
29+
packet_sender: sender,
3030
}
3131
}
3232
pub fn src_addr(&self) -> IpAddr {
@@ -68,7 +68,7 @@ impl IpStackUnknownTransport {
6868
Ok(NetworkPacket {
6969
ip: IpHeader::Ipv4(ip_h),
7070
transport: TransportHeader::Unknown,
71-
payload: Some(p),
71+
payload: Some(p.into()),
7272
})
7373
}
7474
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
@@ -91,7 +91,7 @@ impl IpStackUnknownTransport {
9191
Ok(NetworkPacket {
9292
ip: IpHeader::Ipv6(ip_h),
9393
transport: TransportHeader::Unknown,
94-
payload: Some(p),
94+
payload: Some(p.into()),
9595
})
9696
}
9797
_ => unreachable!(),

0 commit comments

Comments
 (0)