Skip to content

Commit a8f97f0

Browse files
committed
feat: pass stream and sink to websocket handler
1 parent cebec01 commit a8f97f0

File tree

4 files changed

+56
-46
lines changed

4 files changed

+56
-46
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Hudsucker is a MITM HTTP/S proxy written in Rust that allows you to:
88

99
- Modify HTTP/S requests
1010
- Modify HTTP/S responses
11-
- Modify websocket messages
11+
- Modify WebSocket messages
1212

1313
## Features
1414

src/lib.rs

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//!
55
//! - Modify HTTP/S requests
66
//! - Modify HTTP/S responses
7-
//! - Modify websocket messages
7+
//! - Modify WebSocket messages
88
//!
99
//! ## Features
1010
//!
@@ -25,14 +25,16 @@ mod rewind;
2525

2626
pub mod certificate_authority;
2727

28+
use futures::{Sink, SinkExt, Stream, StreamExt};
2829
use hyper::{Body, Request, Response, StatusCode, Uri};
2930
use std::net::SocketAddr;
30-
use tokio_tungstenite::tungstenite::Message;
31+
use tokio_tungstenite::tungstenite::{self, Message};
3132
use tracing::error;
3233

3334
pub(crate) use rewind::Rewind;
3435

3536
pub use async_trait;
37+
pub use futures;
3638
pub use hyper;
3739
#[cfg(feature = "openssl-ca")]
3840
pub use openssl;
@@ -98,7 +100,7 @@ pub enum WebSocketContext {
98100
/// Each request/response pair is passed to the same instance of the handler.
99101
#[async_trait::async_trait]
100102
pub trait HttpHandler: Clone + Send + Sync + 'static {
101-
/// The handler will be called for each HTTP request. It can either return a modified request,
103+
/// This handler will be called for each HTTP request. It can either return a modified request,
102104
/// or a response. If a request is returned, it will be sent to the upstream server. If a
103105
/// response is returned, it will be sent to the client.
104106
async fn handle_request(
@@ -109,13 +111,13 @@ pub trait HttpHandler: Clone + Send + Sync + 'static {
109111
req.into()
110112
}
111113

112-
/// The handler will be called for each HTTP response. It can modify a response before it is
114+
/// This handler will be called for each HTTP response. It can modify a response before it is
113115
/// forwarded to the client.
114116
async fn handle_response(&mut self, _ctx: &HttpContext, res: Response<Body>) -> Response<Body> {
115117
res
116118
}
117119

118-
/// The handler will be called if a proxy request fails. Default response is a 502 Bad Gateway.
120+
/// This handler will be called if a proxy request fails. Default response is a 502 Bad Gateway.
119121
async fn handle_error(&mut self, _ctx: &HttpContext, err: hyper::Error) -> Response<Body> {
120122
error!("Failed to forward request: {}", err);
121123
Response::builder()
@@ -130,12 +132,48 @@ pub trait HttpHandler: Clone + Send + Sync + 'static {
130132
}
131133
}
132134

133-
/// Handler for websocket messages.
135+
/// Handler for WebSocket messages.
134136
///
135-
/// Messages sent over the same websocket stream are passed to the same instance of the handler.
137+
/// Messages sent over the same WebSocket Stream are passed to the same instance of the handler.
136138
#[async_trait::async_trait]
137139
pub trait WebSocketHandler: Clone + Send + Sync + 'static {
138-
/// The handler will be called for each websocket message. It can return an optional modified
140+
/// This handler is responsible for forwarding WebSocket messages from a Stream to a Sink and
141+
/// recovering from any potential errors.
142+
async fn handle_websocket(
143+
mut self,
144+
ctx: WebSocketContext,
145+
mut stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
146+
mut sink: impl Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
147+
) {
148+
while let Some(message) = stream.next().await {
149+
match message {
150+
Ok(message) => {
151+
let Some(message) = self.handle_message(&ctx, message).await else {
152+
continue
153+
};
154+
155+
match sink.send(message).await {
156+
Err(tungstenite::Error::ConnectionClosed) => (),
157+
Err(e) => error!("WebSocket send error: {}", e),
158+
_ => (),
159+
}
160+
}
161+
Err(e) => {
162+
error!("WebSocket message error: {}", e);
163+
164+
match sink.send(Message::Close(None)).await {
165+
Err(tungstenite::Error::ConnectionClosed) => (),
166+
Err(e) => error!("WebSocket close error: {}", e),
167+
_ => (),
168+
};
169+
170+
break;
171+
}
172+
}
173+
}
174+
}
175+
176+
/// This handler will be called for each WebSocket message. It can return an optional modified
139177
/// message. If None is returned the message will not be forwarded.
140178
async fn handle_message(
141179
&mut self,

src/noop.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{HttpHandler, WebSocketHandler};
22

33
/// A No-op handler.
44
///
5-
/// When using this handler, HTTP requests and responses and websocket messages will not be
5+
/// When using this handler, HTTP requests and responses and WebSocket messages will not be
66
/// modified.
77
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
88
pub struct NoopHandler(());

src/proxy/internal.rs

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
certificate_authority::CertificateAuthority, HttpContext, HttpHandler, RequestOrResponse,
33
Rewind, WebSocketContext, WebSocketHandler,
44
};
5-
use futures::{Sink, SinkExt, Stream, StreamExt};
5+
use futures::{Sink, Stream, StreamExt};
66
use http::uri::{Authority, Scheme};
77
use hyper::{
88
client::connect::Connect, header::Entry, server::conn::Http, service::service_fn,
@@ -153,7 +153,7 @@ where
153153
if let Err(e) =
154154
self.serve_stream(upgraded, Scheme::HTTP, authority).await
155155
{
156-
error!("Websocket connect error: {}", e);
156+
error!("WebSocket connect error: {}", e);
157157
}
158158

159159
return;
@@ -252,11 +252,11 @@ where
252252
match websocket.await {
253253
Ok(ws) => {
254254
if let Err(e) = self.handle_websocket(ws, req).await {
255-
error!("Failed to handle websocket: {}", e);
255+
error!("Failed to handle WebSocket: {}", e);
256256
}
257257
}
258258
Err(e) => {
259-
error!("Failed to upgrade to websocket: {}", e);
259+
error!("Failed to upgrade to WebSocket: {}", e);
260260
}
261261
}
262262
};
@@ -350,41 +350,13 @@ where
350350
}
351351

352352
fn spawn_message_forwarder(
353-
mut stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
354-
mut sink: impl Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
355-
mut handler: impl WebSocketHandler,
353+
stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
354+
sink: impl Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
355+
handler: impl WebSocketHandler,
356356
ctx: WebSocketContext,
357357
) {
358358
let span = info_span!("message_forwarder", context = ?ctx);
359-
let fut = async move {
360-
while let Some(message) = stream.next().await {
361-
match message {
362-
Ok(message) => {
363-
let Some(message) = handler.handle_message(&ctx, message).await else {
364-
continue
365-
};
366-
367-
match sink.send(message).await {
368-
Err(tungstenite::Error::ConnectionClosed) => (),
369-
Err(e) => error!("Websocket send error: {}", e),
370-
_ => (),
371-
}
372-
}
373-
Err(e) => {
374-
error!("Websocket message error: {}", e);
375-
376-
match sink.send(Message::Close(None)).await {
377-
Err(tungstenite::Error::ConnectionClosed) => (),
378-
Err(e) => error!("Websocket close error: {}", e),
379-
_ => (),
380-
};
381-
382-
break;
383-
}
384-
}
385-
}
386-
};
387-
359+
let fut = handler.handle_websocket(ctx, stream, sink);
388360
spawn_with_trace(fut, span);
389361
}
390362

0 commit comments

Comments
 (0)