11{-# LANGUAGE BangPatterns #-}
22{-# LANGUAGE DerivingStrategies #-}
33{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4+ {-# LANGUAGE LambdaCase #-}
45{-# LANGUAGE NumericUnderscores #-}
56{-# LANGUAGE ScopedTypeVariables #-}
6- {-# LANGUAGE LambdaCase #-}
77
8- module Network.Transport.QUIC.Internal.Messaging
9- ( -- * Connections
10- ServerConnId ,
11- serverSelfConnId ,
12- firstNonReservedServerConnId ,
13- ClientConnId ,
14- createConnectionId ,
15- sendMessage ,
16- receiveMessage ,
17- MessageReceived (.. ),
18-
19- -- * Specialized messages
20- sendAck ,
21- sendRejection ,
22- recvAck ,
23- recvWord32 ,
24- sendCloseConnection ,
25- sendCloseEndPoint ,
26-
27- -- * Handshake protocol
28- handshake ,
29-
30- -- * Re-exported for testing
31- encodeMessage ,
32- decodeMessage ,
33- )
8+ module Network.Transport.QUIC.Internal.Messaging (
9+ -- * Connections
10+ ServerConnId ,
11+ serverSelfConnId ,
12+ firstNonReservedServerConnId ,
13+ ClientConnId ,
14+ createConnectionId ,
15+ sendMessage ,
16+ receiveMessage ,
17+ MessageReceived (.. ),
18+
19+ -- * Specialized messages
20+ sendAck ,
21+ sendRejection ,
22+ recvAck ,
23+ recvWord32 ,
24+ sendCloseConnection ,
25+ sendCloseEndPoint ,
26+
27+ -- * Handshake protocol
28+ handshake ,
29+
30+ -- * Re-exported for testing
31+ encodeMessage ,
32+ decodeMessage ,
33+ )
3434where
3535
36- import Control.Exception (catch , displayException , mask , throwIO , try , SomeException )
36+ import Control.Exception (SomeException , catch , displayException , mask , throwIO , try )
37+ import Control.Monad (replicateM )
3738import Data.Binary (Binary )
3839import Data.Binary qualified as Binary
3940import Data.Bits (shiftL , (.|.) )
@@ -49,10 +50,11 @@ import Network.Transport.Internal (decodeWord32, encodeWord32)
4950import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (QUICAddr ), decodeQUICAddr )
5051import System.Timeout (timeout )
5152
52- -- | Send a message to a remote endpoint ID
53- --
54- -- This function is thread-safe; while the data is sending, asynchronous
55- -- exceptions are masked, to be rethrown after the data is sent.
53+ {- | Send a message to a remote endpoint ID
54+
55+ This function is thread-safe; while the data is sending, asynchronous
56+ exceptions are masked, to be rethrown after the data is sent.
57+ -}
5658sendMessage ::
5759 Stream ->
5860 ClientConnId ->
@@ -65,10 +67,11 @@ sendMessage stream connId messages =
6567 (encodeMessage connId messages)
6668 )
6769
68- -- | Receive a message, including its local destination endpoint ID
69- --
70- -- This function is thread-safe; while the data is being received, asynchronous
71- -- exceptions are masked, to be rethrown after the data is sent.
70+ {- | Receive a message, including its local destination endpoint ID
71+
72+ This function is thread-safe; while the data is being received, asynchronous
73+ exceptions are masked, to be rethrown after the data is sent.
74+ -}
7275receiveMessage ::
7376 Stream ->
7477 IO (Either String MessageReceived )
@@ -81,66 +84,65 @@ receiveMessage stream = mask $ \restore ->
8184 )
8285 `catch` (\ (ex :: QUIC. QUICException ) -> throwIO ex)
8386
84- -- | Encode a message.
85- --
86- -- The encoding is composed of a header, and the payload.
87- -- The message header is composed of two 32-bit numbers:
88- -- The endpoint ID of the destination endpoint, padded to a 32-bit big endian number;
89- -- The length of the payload, again padded to a 32-bit big endian number
87+ {- | Encode a message.
88+
89+ The encoding is composed of a header, and the payloads.
90+ The message header is composed of:
91+ 1. A control byte, to determine how the message should be parsed.
92+ 2. A 32-bit word that encodes the endpoint ID of the destination endpoint;
93+ 3. A 32-bit word that encodes the number of frames in the message
94+
95+ The payload frames are each prepended with the length of the frame.
96+ -}
9097encodeMessage ::
9198 ClientConnId ->
9299 [ByteString ] ->
93100 [ByteString ]
94- encodeMessage connId =
95- -- For simplicity, we are keeping the message boundaries, and adding
96- -- a header for each message.
97- --
98- -- We could also merge all messages together, and have a single
99- -- header, but this requires specifying some message framing
100- fmap withHeader
101- where
102- withHeader message =
103- BS. concat
104- [ BS. singleton messageControlByte,
105- encodeWord32 (fromIntegral connId),
106- encodeWord32 (fromIntegral (BS. length message)),
107- message
108- ]
101+ encodeMessage connId messages =
102+ BS. concat
103+ [ BS. singleton messageControlByte
104+ , encodeWord32 (fromIntegral connId)
105+ , encodeWord32 (fromIntegral $ length messages)
106+ ]
107+ : [encodeWord32 (fromIntegral $ BS. length message) <> message | message <- messages]
109108
110109decodeMessage :: (Int -> IO ByteString ) -> IO (Either String MessageReceived )
111110decodeMessage get =
112111 get 1 >>= maybe (pure $ Right StreamClosed ) go . flip BS. indexMaybe 0
113- where
114- go ctrl
115- | ctrl == closeEndPointControlByte = pure $ Right CloseEndPoint
116- | ctrl == closeConnectionControlByte = Right . CloseConnection . fromIntegral <$> getWord32
117- | ctrl == messageControlByte = do
118- connId <- getWord32
119- messageLength <- getWord32
120- get (fromIntegral messageLength) <&> Right . Message (fromIntegral connId)
121- | otherwise = pure $ Left $ " Unsupported control byte: " <> show ctrl
122- getWord32 = get 4 <&> decodeWord32
123-
124- -- | Wrap a method to fetch bytes, to ensure that we always get exactly the
125- -- intended number of bytes.
112+ where
113+ go ctrl
114+ | ctrl == closeEndPointControlByte = pure $ Right CloseEndPoint
115+ | ctrl == closeConnectionControlByte = Right . CloseConnection . fromIntegral <$> getWord32
116+ | ctrl == messageControlByte = do
117+ connId <- getWord32
118+ numMessages <- getWord32
119+ messages <- replicateM (fromIntegral numMessages) $ do
120+ getWord32 >>= get . fromIntegral
121+ pure . Right $ Message (fromIntegral connId) messages
122+ | otherwise = pure $ Left $ " Unsupported control byte: " <> show ctrl
123+ getWord32 = get 4 <&> decodeWord32
124+
125+ {- | Wrap a method to fetch bytes, to ensure that we always get exactly the
126+ intended number of bytes.
127+ -}
126128getAllBytes ::
127129 -- | Function to fetch at most 'n' bytes
128130 (Int -> IO ByteString ) ->
129131 -- | Function to fetch exactly 'n' bytes
130132 (Int -> IO ByteString )
131133getAllBytes get n = go n mempty
132- where
133- go 0 ! acc = pure $ BS. concat acc
134- go m ! acc =
135- get m >>= \ bytes ->
136- go
137- (m - BS. length bytes)
138- (acc <> [bytes])
134+ where
135+ go 0 ! acc = pure $ BS. concat acc
136+ go m ! acc =
137+ get m >>= \ bytes ->
138+ go
139+ (m - BS. length bytes)
140+ (acc <> [bytes])
139141
140142data MessageReceived
141143 = Message
142144 {- # UNPACK #-} !ClientConnId
143- {- # UNPACK #-} !ByteString
145+ {- # UNPACK #-} ![ ByteString ]
144146 | CloseConnection ! ClientConnId
145147 | CloseEndPoint
146148 | StreamClosed
@@ -176,16 +178,17 @@ recvAck stream = do
176178 >>= maybe
177179 (throwIO (AckException " Connection ack not received within acceptable timeframe" ))
178180 go
179- where
180- go response
181- | response == ackMessage = pure $ Right ()
182- | response == rejectMessage = pure $ Left ()
183- | otherwise = throwIO (AckException " Unexpected ack response" )
184-
185- -- | Receive a 'Word32'
186- --
187- -- This function is thread-safe; while the data is being received, asynchronous
188- -- exceptions are masked, to be rethrown after the data is sent.
181+ where
182+ go response
183+ | response == ackMessage = pure $ Right ()
184+ | response == rejectMessage = pure $ Left ()
185+ | otherwise = throwIO (AckException " Unexpected ack response" )
186+
187+ {- | Receive a 'Word32'
188+
189+ This function is thread-safe; while the data is being received, asynchronous
190+ exceptions are masked, to be rethrown after the data is sent.
191+ -}
189192recvWord32 ::
190193 Stream ->
191194 IO (Either String Word32 )
@@ -196,8 +199,9 @@ recvWord32 stream =
196199 )
197200 `catch` (\ (ex :: QUIC. QUICException ) -> pure $ Left (displayException ex))
198201
199- -- | We perform some special actions based on a message's control byte.
200- -- For example, if a client wants to close a connection.
202+ {- | We perform some special actions based on a message's control byte.
203+ For example, if a client wants to close a connection.
204+ -}
201205type ControlByte = Word8
202206
203207connectionAcceptedControlByte :: ControlByte
@@ -235,8 +239,9 @@ sendCloseEndPoint stream =
235239 )
236240 )
237241
238- -- | Handshake protocol that a client, connecting to a remote endpoint,
239- -- has to perform.
242+ {- | Handshake protocol that a client, connecting to a remote endpoint,
243+ has to perform.
244+ -}
240245
241246-- TODO: encode server part of the handhake
242247handshake ::
@@ -253,18 +258,18 @@ handshake (ourAddress, theirAddress) stream =
253258 let encodedPayload = BS. toStrict $ Binary. encode (ourAddress, serverEndPointId)
254259 payloadLength = encodeWord32 $ fromIntegral (BS. length encodedPayload)
255260
256- try (
257- QUIC. sendStream
258- stream
259- (BS. concat [payloadLength, encodedPayload]) )
260- >>= \ case
261- Left (_exc :: SomeException ) -> pure $ Left ()
262- Right _ ->
263-
264- -- Server acknowledgement that the handshake is complete
265- -- means that we cannot send messages until the server
266- -- is ready for them
267- recvAck stream
261+ try
262+ ( QUIC. sendStream
263+ stream
264+ (BS. concat [payloadLength, encodedPayload])
265+ )
266+ >>= \ case
267+ Left (_exc :: SomeException ) -> pure $ Left ()
268+ Right _ ->
269+ -- Server acknowledgement that the handshake is complete
270+ -- means that we cannot send messages until the server
271+ -- is ready for them
272+ recvAck stream
268273
269274-- | Part of the connection ID that is client-allocated.
270275newtype ClientConnId = ClientConnId Word32
0 commit comments