44//!
55//! - Modify HTTP/S requests
66//! - Modify HTTP/S responses
7- //! - Modify websocket messages
7+ //! - Modify WebSocket messages
88//!
99//! ## Features
1010//!
@@ -25,14 +25,16 @@ mod rewind;
2525
2626pub mod certificate_authority;
2727
28+ use futures:: { Sink , SinkExt , Stream , StreamExt } ;
2829use hyper:: { Body , Request , Response , StatusCode , Uri } ;
2930use std:: net:: SocketAddr ;
30- use tokio_tungstenite:: tungstenite:: Message ;
31+ use tokio_tungstenite:: tungstenite:: { self , Message } ;
3132use tracing:: error;
3233
3334pub ( crate ) use rewind:: Rewind ;
3435
3536pub use async_trait;
37+ pub use futures;
3638pub use hyper;
3739#[ cfg( feature = "openssl-ca" ) ]
3840pub use openssl;
@@ -98,7 +100,7 @@ pub enum WebSocketContext {
98100/// Each request/response pair is passed to the same instance of the handler.
99101#[ async_trait:: async_trait]
100102pub trait HttpHandler : Clone + Send + Sync + ' static {
101- /// The handler will be called for each HTTP request. It can either return a modified request,
103+ /// This handler will be called for each HTTP request. It can either return a modified request,
102104 /// or a response. If a request is returned, it will be sent to the upstream server. If a
103105 /// response is returned, it will be sent to the client.
104106 async fn handle_request (
@@ -109,13 +111,13 @@ pub trait HttpHandler: Clone + Send + Sync + 'static {
109111 req. into ( )
110112 }
111113
112- /// The handler will be called for each HTTP response. It can modify a response before it is
114+ /// This handler will be called for each HTTP response. It can modify a response before it is
113115 /// forwarded to the client.
114116 async fn handle_response ( & mut self , _ctx : & HttpContext , res : Response < Body > ) -> Response < Body > {
115117 res
116118 }
117119
118- /// The handler will be called if a proxy request fails. Default response is a 502 Bad Gateway.
120+ /// This handler will be called if a proxy request fails. Default response is a 502 Bad Gateway.
119121 async fn handle_error ( & mut self , _ctx : & HttpContext , err : hyper:: Error ) -> Response < Body > {
120122 error ! ( "Failed to forward request: {}" , err) ;
121123 Response :: builder ( )
@@ -130,12 +132,48 @@ pub trait HttpHandler: Clone + Send + Sync + 'static {
130132 }
131133}
132134
133- /// Handler for websocket messages.
135+ /// Handler for WebSocket messages.
134136///
135- /// Messages sent over the same websocket stream are passed to the same instance of the handler.
137+ /// Messages sent over the same WebSocket Stream are passed to the same instance of the handler.
136138#[ async_trait:: async_trait]
137139pub trait WebSocketHandler : Clone + Send + Sync + ' static {
138- /// The handler will be called for each websocket message. It can return an optional modified
140+ /// This handler is responsible for forwarding WebSocket messages from a Stream to a Sink and
141+ /// recovering from any potential errors.
142+ async fn handle_websocket (
143+ mut self ,
144+ ctx : WebSocketContext ,
145+ mut stream : impl Stream < Item = Result < Message , tungstenite:: Error > > + Unpin + Send + ' static ,
146+ mut sink : impl Sink < Message , Error = tungstenite:: Error > + Unpin + Send + ' static ,
147+ ) {
148+ while let Some ( message) = stream. next ( ) . await {
149+ match message {
150+ Ok ( message) => {
151+ let Some ( message) = self . handle_message ( & ctx, message) . await else {
152+ continue
153+ } ;
154+
155+ match sink. send ( message) . await {
156+ Err ( tungstenite:: Error :: ConnectionClosed ) => ( ) ,
157+ Err ( e) => error ! ( "WebSocket send error: {}" , e) ,
158+ _ => ( ) ,
159+ }
160+ }
161+ Err ( e) => {
162+ error ! ( "WebSocket message error: {}" , e) ;
163+
164+ match sink. send ( Message :: Close ( None ) ) . await {
165+ Err ( tungstenite:: Error :: ConnectionClosed ) => ( ) ,
166+ Err ( e) => error ! ( "WebSocket close error: {}" , e) ,
167+ _ => ( ) ,
168+ } ;
169+
170+ break ;
171+ }
172+ }
173+ }
174+ }
175+
176+ /// This handler will be called for each WebSocket message. It can return an optional modified
139177 /// message. If None is returned the message will not be forwarded.
140178 async fn handle_message (
141179 & mut self ,
0 commit comments