Skip to content

Commit 6808dd4

Browse files
committed
Framing mechanism for messages
1 parent b2668a6 commit 6808dd4

File tree

3 files changed

+113
-110
lines changed

3 files changed

+113
-110
lines changed

packages/network-transport-quic/src/Network/Transport/QUIC/Internal.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,9 @@ handleIncomingMessages ourEndPoint remoteEndPoint =
287287
(writeTQueue ourQueue . ConnectionClosed . connectionId)
288288
)
289289

290-
handleMessage :: ClientConnId -> ByteString -> IO ()
290+
handleMessage :: ClientConnId -> [ByteString] -> IO ()
291291
handleMessage clientConnId payload =
292-
atomically (writeTQueue ourQueue (Received (connectionId clientConnId) [payload]))
292+
atomically (writeTQueue ourQueue (Received (connectionId clientConnId) payload))
293293

294294
prematureExit :: IOException -> IO ()
295295
prematureExit exc = do

packages/network-transport-quic/src/Network/Transport/QUIC/Internal/Messaging.hs

Lines changed: 109 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,40 @@
11
{-# LANGUAGE BangPatterns #-}
22
{-# LANGUAGE DerivingStrategies #-}
33
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4+
{-# LANGUAGE LambdaCase #-}
45
{-# LANGUAGE NumericUnderscores #-}
56
{-# LANGUAGE ScopedTypeVariables #-}
6-
{-# LANGUAGE LambdaCase #-}
77

8-
module Network.Transport.QUIC.Internal.Messaging
9-
( -- * Connections
10-
ServerConnId,
11-
serverSelfConnId,
12-
firstNonReservedServerConnId,
13-
ClientConnId,
14-
createConnectionId,
15-
sendMessage,
16-
receiveMessage,
17-
MessageReceived (..),
18-
19-
-- * Specialized messages
20-
sendAck,
21-
sendRejection,
22-
recvAck,
23-
recvWord32,
24-
sendCloseConnection,
25-
sendCloseEndPoint,
26-
27-
-- * Handshake protocol
28-
handshake,
29-
30-
-- * Re-exported for testing
31-
encodeMessage,
32-
decodeMessage,
33-
)
8+
module Network.Transport.QUIC.Internal.Messaging (
9+
-- * Connections
10+
ServerConnId,
11+
serverSelfConnId,
12+
firstNonReservedServerConnId,
13+
ClientConnId,
14+
createConnectionId,
15+
sendMessage,
16+
receiveMessage,
17+
MessageReceived (..),
18+
19+
-- * Specialized messages
20+
sendAck,
21+
sendRejection,
22+
recvAck,
23+
recvWord32,
24+
sendCloseConnection,
25+
sendCloseEndPoint,
26+
27+
-- * Handshake protocol
28+
handshake,
29+
30+
-- * Re-exported for testing
31+
encodeMessage,
32+
decodeMessage,
33+
)
3434
where
3535

36-
import Control.Exception (catch, displayException, mask, throwIO, try, SomeException)
36+
import Control.Exception (SomeException, catch, displayException, mask, throwIO, try)
37+
import Control.Monad (replicateM)
3738
import Data.Binary (Binary)
3839
import Data.Binary qualified as Binary
3940
import Data.Bits (shiftL, (.|.))
@@ -49,10 +50,11 @@ import Network.Transport.Internal (decodeWord32, encodeWord32)
4950
import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (QUICAddr), decodeQUICAddr)
5051
import System.Timeout (timeout)
5152

52-
-- | Send a message to a remote endpoint ID
53-
--
54-
-- This function is thread-safe; while the data is sending, asynchronous
55-
-- exceptions are masked, to be rethrown after the data is sent.
53+
{- | Send a message to a remote endpoint ID
54+
55+
This function is thread-safe; while the data is sending, asynchronous
56+
exceptions are masked, to be rethrown after the data is sent.
57+
-}
5658
sendMessage ::
5759
Stream ->
5860
ClientConnId ->
@@ -65,10 +67,11 @@ sendMessage stream connId messages =
6567
(encodeMessage connId messages)
6668
)
6769

68-
-- | Receive a message, including its local destination endpoint ID
69-
--
70-
-- This function is thread-safe; while the data is being received, asynchronous
71-
-- exceptions are masked, to be rethrown after the data is sent.
70+
{- | Receive a message, including its local destination endpoint ID
71+
72+
This function is thread-safe; while the data is being received, asynchronous
73+
exceptions are masked, to be rethrown after the data is sent.
74+
-}
7275
receiveMessage ::
7376
Stream ->
7477
IO (Either String MessageReceived)
@@ -81,66 +84,65 @@ receiveMessage stream = mask $ \restore ->
8184
)
8285
`catch` (\(ex :: QUIC.QUICException) -> throwIO ex)
8386

84-
-- | Encode a message.
85-
--
86-
-- The encoding is composed of a header, and the payload.
87-
-- The message header is composed of two 32-bit numbers:
88-
-- The endpoint ID of the destination endpoint, padded to a 32-bit big endian number;
89-
-- The length of the payload, again padded to a 32-bit big endian number
87+
{- | Encode a message.
88+
89+
The encoding is composed of a header, and the payloads.
90+
The message header is composed of:
91+
1. A control byte, to determine how the message should be parsed.
92+
2. A 32-bit word that encodes the endpoint ID of the destination endpoint;
93+
3. A 32-bit word that encodes the number of frames in the message
94+
95+
The payload frames are each prepended with the length of the frame.
96+
-}
9097
encodeMessage ::
9198
ClientConnId ->
9299
[ByteString] ->
93100
[ByteString]
94-
encodeMessage connId =
95-
-- For simplicity, we are keeping the message boundaries, and adding
96-
-- a header for each message.
97-
--
98-
-- We could also merge all messages together, and have a single
99-
-- header, but this requires specifying some message framing
100-
fmap withHeader
101-
where
102-
withHeader message =
103-
BS.concat
104-
[ BS.singleton messageControlByte,
105-
encodeWord32 (fromIntegral connId),
106-
encodeWord32 (fromIntegral (BS.length message)),
107-
message
108-
]
101+
encodeMessage connId messages =
102+
BS.concat
103+
[ BS.singleton messageControlByte
104+
, encodeWord32 (fromIntegral connId)
105+
, encodeWord32 (fromIntegral $ length messages)
106+
]
107+
: [encodeWord32 (fromIntegral $ BS.length message) <> message | message <- messages]
109108

110109
decodeMessage :: (Int -> IO ByteString) -> IO (Either String MessageReceived)
111110
decodeMessage get =
112111
get 1 >>= maybe (pure $ Right StreamClosed) go . flip BS.indexMaybe 0
113-
where
114-
go ctrl
115-
| ctrl == closeEndPointControlByte = pure $ Right CloseEndPoint
116-
| ctrl == closeConnectionControlByte = Right . CloseConnection . fromIntegral <$> getWord32
117-
| ctrl == messageControlByte = do
118-
connId <- getWord32
119-
messageLength <- getWord32
120-
get (fromIntegral messageLength) <&> Right . Message (fromIntegral connId)
121-
| otherwise = pure $ Left $ "Unsupported control byte: " <> show ctrl
122-
getWord32 = get 4 <&> decodeWord32
123-
124-
-- | Wrap a method to fetch bytes, to ensure that we always get exactly the
125-
-- intended number of bytes.
112+
where
113+
go ctrl
114+
| ctrl == closeEndPointControlByte = pure $ Right CloseEndPoint
115+
| ctrl == closeConnectionControlByte = Right . CloseConnection . fromIntegral <$> getWord32
116+
| ctrl == messageControlByte = do
117+
connId <- getWord32
118+
numMessages <- getWord32
119+
messages <- replicateM (fromIntegral numMessages) $ do
120+
getWord32 >>= get . fromIntegral
121+
pure . Right $ Message (fromIntegral connId) messages
122+
| otherwise = pure $ Left $ "Unsupported control byte: " <> show ctrl
123+
getWord32 = get 4 <&> decodeWord32
124+
125+
{- | Wrap a method to fetch bytes, to ensure that we always get exactly the
126+
intended number of bytes.
127+
-}
126128
getAllBytes ::
127129
-- | Function to fetch at most 'n' bytes
128130
(Int -> IO ByteString) ->
129131
-- | Function to fetch exactly 'n' bytes
130132
(Int -> IO ByteString)
131133
getAllBytes get n = go n mempty
132-
where
133-
go 0 !acc = pure $ BS.concat acc
134-
go m !acc =
135-
get m >>= \bytes ->
136-
go
137-
(m - BS.length bytes)
138-
(acc <> [bytes])
134+
where
135+
go 0 !acc = pure $ BS.concat acc
136+
go m !acc =
137+
get m >>= \bytes ->
138+
go
139+
(m - BS.length bytes)
140+
(acc <> [bytes])
139141

140142
data MessageReceived
141143
= Message
142144
{-# UNPACK #-} !ClientConnId
143-
{-# UNPACK #-} !ByteString
145+
{-# UNPACK #-} ![ByteString]
144146
| CloseConnection !ClientConnId
145147
| CloseEndPoint
146148
| StreamClosed
@@ -176,16 +178,17 @@ recvAck stream = do
176178
>>= maybe
177179
(throwIO (AckException "Connection ack not received within acceptable timeframe"))
178180
go
179-
where
180-
go response
181-
| response == ackMessage = pure $ Right ()
182-
| response == rejectMessage = pure $ Left ()
183-
| otherwise = throwIO (AckException "Unexpected ack response")
184-
185-
-- | Receive a 'Word32'
186-
--
187-
-- This function is thread-safe; while the data is being received, asynchronous
188-
-- exceptions are masked, to be rethrown after the data is sent.
181+
where
182+
go response
183+
| response == ackMessage = pure $ Right ()
184+
| response == rejectMessage = pure $ Left ()
185+
| otherwise = throwIO (AckException "Unexpected ack response")
186+
187+
{- | Receive a 'Word32'
188+
189+
This function is thread-safe; while the data is being received, asynchronous
190+
exceptions are masked, to be rethrown after the data is sent.
191+
-}
189192
recvWord32 ::
190193
Stream ->
191194
IO (Either String Word32)
@@ -196,8 +199,9 @@ recvWord32 stream =
196199
)
197200
`catch` (\(ex :: QUIC.QUICException) -> pure $ Left (displayException ex))
198201

199-
-- | We perform some special actions based on a message's control byte.
200-
-- For example, if a client wants to close a connection.
202+
{- | We perform some special actions based on a message's control byte.
203+
For example, if a client wants to close a connection.
204+
-}
201205
type ControlByte = Word8
202206

203207
connectionAcceptedControlByte :: ControlByte
@@ -235,8 +239,9 @@ sendCloseEndPoint stream =
235239
)
236240
)
237241

238-
-- | Handshake protocol that a client, connecting to a remote endpoint,
239-
-- has to perform.
242+
{- | Handshake protocol that a client, connecting to a remote endpoint,
243+
has to perform.
244+
-}
240245

241246
-- TODO: encode server part of the handhake
242247
handshake ::
@@ -253,18 +258,18 @@ handshake (ourAddress, theirAddress) stream =
253258
let encodedPayload = BS.toStrict $ Binary.encode (ourAddress, serverEndPointId)
254259
payloadLength = encodeWord32 $ fromIntegral (BS.length encodedPayload)
255260

256-
try (
257-
QUIC.sendStream
258-
stream
259-
(BS.concat [payloadLength, encodedPayload]))
260-
>>= \case
261-
Left (_exc :: SomeException) -> pure $ Left ()
262-
Right _ ->
263-
264-
-- Server acknowledgement that the handshake is complete
265-
-- means that we cannot send messages until the server
266-
-- is ready for them
267-
recvAck stream
261+
try
262+
( QUIC.sendStream
263+
stream
264+
(BS.concat [payloadLength, encodedPayload])
265+
)
266+
>>= \case
267+
Left (_exc :: SomeException) -> pure $ Left ()
268+
Right _ ->
269+
-- Server acknowledgement that the handshake is complete
270+
-- means that we cannot send messages until the server
271+
-- is ready for them
272+
recvAck stream
268273

269274
-- | Part of the connection ID that is client-allocated.
270275
newtype ClientConnId = ClientConnId Word32

packages/network-transport-quic/test/Test/Network/Transport/QUIC/Internal/Messaging.hs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
{-# LANGUAGE OverloadedStrings #-}
2-
{-# LANGUAGE LambdaCase #-}
32

43
module Test.Network.Transport.QUIC.Internal.Messaging (tests) where
54

6-
import Control.Monad (replicateM)
75
import Control.Monad.IO.Class (MonadIO (liftIO))
86
import Data.ByteString (ByteString)
97
import Data.ByteString qualified as BS
@@ -35,8 +33,8 @@ testMessageEncodingAndDecoding = testProperty "Encoded messages can be decoded"
3533

3634
getBytes <- liftIO $ messageDecoder encoded
3735

38-
decoded <- liftIO $ replicateM (length messages) (decodeMessage getBytes)
39-
(Right . Message endpointId <$> messages) === decoded
36+
decoded <- liftIO $ decodeMessage getBytes
37+
Right (Message endpointId messages) === decoded
4038

4139
messageDecoder :: ByteString -> IO (Int -> IO ByteString)
4240
messageDecoder allBytes = do

0 commit comments

Comments
 (0)