Skip to content

Commit 4165da0

Browse files
committed
feat: allow http handler to control whether to intercept CONNECT request
1 parent 02193f3 commit 4165da0

File tree

6 files changed

+206
-85
lines changed

6 files changed

+206
-85
lines changed

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ pub trait HttpHandler: Clone + Send + Sync + 'static {
113113
async fn handle_response(&mut self, _ctx: &HttpContext, res: Response<Body>) -> Response<Body> {
114114
res
115115
}
116+
117+
/// Whether a CONNECT request should be intercepted. Defaults to `true` for all requests.
118+
async fn should_intercept(&mut self, _req: &Request<Body>) -> bool {
119+
true
120+
}
116121
}
117122

118123
/// Handler for websocket messages.

src/proxy/internal.rs

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ where
115115
}
116116
}
117117

118-
fn process_connect(self, mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
118+
fn process_connect(mut self, mut req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
119119
match req.uri().authority().cloned() {
120120
Some(authority) => {
121121
let span = info_span!("process_connect");
@@ -136,60 +136,66 @@ where
136136
bytes::Bytes::copy_from_slice(buffer[..bytes_read].as_ref()),
137137
);
138138

139-
if buffer == *b"GET " {
140-
if let Err(e) =
141-
self.serve_stream(upgraded, Scheme::HTTP, authority).await
142-
{
143-
error!("Websocket connect error: {}", e);
144-
}
145-
} else if buffer[..2] == *b"\x16\x03" {
146-
let server_config = self
147-
.ca
148-
.gen_server_config(&authority)
149-
.instrument(info_span!("gen_server_config"))
150-
.await;
151-
152-
let stream =
153-
match TlsAcceptor::from(server_config).accept(upgraded).await {
139+
if self.http_handler.should_intercept(&req).await {
140+
if buffer == *b"GET " {
141+
if let Err(e) =
142+
self.serve_stream(upgraded, Scheme::HTTP, authority).await
143+
{
144+
error!("Websocket connect error: {}", e);
145+
}
146+
147+
return;
148+
} else if buffer[..2] == *b"\x16\x03" {
149+
let server_config = self
150+
.ca
151+
.gen_server_config(&authority)
152+
.instrument(info_span!("gen_server_config"))
153+
.await;
154+
155+
let stream = match TlsAcceptor::from(server_config)
156+
.accept(upgraded)
157+
.await
158+
{
154159
Ok(stream) => stream,
155160
Err(e) => {
156161
error!("Failed to establish TLS connection: {}", e);
157162
return;
158163
}
159164
};
160165

161-
if let Err(e) =
162-
self.serve_stream(stream, Scheme::HTTPS, authority).await
163-
{
164-
if !e.to_string().starts_with("error shutting down connection")
166+
if let Err(e) =
167+
self.serve_stream(stream, Scheme::HTTPS, authority).await
165168
{
166-
error!("HTTPS connect error: {}", e);
167-
}
168-
}
169-
} else {
170-
warn!(
171-
"Unknown protocol, read '{:02X?}' from upgraded connection",
172-
&buffer[..bytes_read]
173-
);
174-
175-
let mut server = match TcpStream::connect(authority.as_ref()).await
176-
{
177-
Ok(server) => server,
178-
Err(e) => {
179-
error!("Failed to connect to {}: {}", authority, e);
180-
return;
169+
if !e
170+
.to_string()
171+
.starts_with("error shutting down connection")
172+
{
173+
error!("HTTPS connect error: {}", e);
174+
}
181175
}
182-
};
183-
184-
if let Err(e) =
185-
tokio::io::copy_bidirectional(&mut upgraded, &mut server).await
186-
{
187-
error!(
188-
"Failed to tunnel unknown protocol to {}: {}",
189-
authority, e
176+
177+
return;
178+
} else {
179+
warn!(
180+
"Unknown protocol, read '{:02X?}' from upgraded connection",
181+
&buffer[..bytes_read]
190182
);
191183
}
192184
}
185+
186+
let mut server = match TcpStream::connect(authority.as_ref()).await {
187+
Ok(server) => server,
188+
Err(e) => {
189+
error!("Failed to connect to {}: {}", authority, e);
190+
return;
191+
}
192+
};
193+
194+
if let Err(e) =
195+
tokio::io::copy_bidirectional(&mut upgraded, &mut server).await
196+
{
197+
error!("Failed to tunnel to {}: {}", authority, e);
198+
}
193199
}
194200
Err(e) => error!("Upgrade error: {}", e),
195201
};

tests/common/mod.rs

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ pub async fn start_https_server(
113113
Ok((addr, tx))
114114
}
115115

116+
pub fn http_client() -> Client<HttpConnector> {
117+
Client::new()
118+
}
119+
120+
pub fn plain_websocket_connector() -> tokio_tungstenite::Connector {
121+
tokio_tungstenite::Connector::Plain
122+
}
123+
116124
fn rustls_client_config() -> rustls::ClientConfig {
117125
let mut roots = rustls::RootCertStore::empty();
118126

@@ -174,37 +182,57 @@ pub fn native_tls_client() -> Client<hyper_tls::HttpsConnector<HttpConnector>> {
174182
Client::builder().build(https)
175183
}
176184

177-
type TestHandlers = (TestHttpHandler, TestWebSocketHandler);
178-
179185
pub fn start_proxy<C>(
180186
ca: impl CertificateAuthority,
181187
client: Client<C>,
182188
websocket_connector: tokio_tungstenite::Connector,
183-
) -> Result<(SocketAddr, TestHandlers, Sender<()>), Box<dyn std::error::Error>>
189+
) -> Result<(SocketAddr, TestHandler, Sender<()>), Box<dyn std::error::Error>>
190+
where
191+
C: Connect + Clone + Send + Sync + 'static,
192+
{
193+
_start_proxy(ca, client, websocket_connector, true)
194+
}
195+
196+
pub fn start_proxy_without_intercept<C>(
197+
ca: impl CertificateAuthority,
198+
client: Client<C>,
199+
websocket_connector: tokio_tungstenite::Connector,
200+
) -> Result<(SocketAddr, TestHandler, Sender<()>), Box<dyn std::error::Error>>
201+
where
202+
C: Connect + Clone + Send + Sync + 'static,
203+
{
204+
_start_proxy(ca, client, websocket_connector, false)
205+
}
206+
207+
fn _start_proxy<C>(
208+
ca: impl CertificateAuthority,
209+
client: Client<C>,
210+
websocket_connector: tokio_tungstenite::Connector,
211+
should_intercept: bool,
212+
) -> Result<(SocketAddr, TestHandler, Sender<()>), Box<dyn std::error::Error>>
184213
where
185214
C: Connect + Clone + Send + Sync + 'static,
186215
{
187216
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?;
188217
let addr = listener.local_addr()?;
189218
let (tx, rx) = tokio::sync::oneshot::channel();
190219

191-
let http_handler = TestHttpHandler::new();
192-
let websocket_handler = TestWebSocketHandler::new();
220+
let handler = TestHandler::new(should_intercept);
193221

194222
let proxy = Proxy::builder()
195223
.with_listener(listener)
196224
.with_client(client)
197225
.with_ca(ca)
198-
.with_http_handler(http_handler.clone())
199-
.with_websocket_handler(websocket_handler.clone())
226+
.with_http_handler(handler.clone())
227+
.with_websocket_handler(handler.clone())
200228
.with_websocket_connector(websocket_connector)
201229
.build();
202230

203231
tokio::spawn(proxy.start(async {
204232
rx.await.unwrap_or_default();
205233
}));
206234

207-
Ok((addr, (http_handler, websocket_handler), tx))
235+
Ok((addr, handler, tx))
208236
}
209237

210238
pub fn start_noop_proxy(
@@ -242,22 +270,26 @@ pub fn build_client(proxy: &str) -> reqwest::Client {
242270
}
243271

244272
#[derive(Clone)]
245-
pub struct TestHttpHandler {
273+
pub struct TestHandler {
246274
pub request_counter: Arc<AtomicUsize>,
247275
pub response_counter: Arc<AtomicUsize>,
276+
pub message_counter: Arc<AtomicUsize>,
277+
pub should_intercept: bool,
248278
}
249279

250-
impl TestHttpHandler {
251-
pub fn new() -> Self {
280+
impl TestHandler {
281+
pub fn new(should_intercept: bool) -> Self {
252282
Self {
253283
request_counter: Arc::new(AtomicUsize::new(0)),
254284
response_counter: Arc::new(AtomicUsize::new(0)),
285+
message_counter: Arc::new(AtomicUsize::new(0)),
286+
should_intercept,
255287
}
256288
}
257289
}
258290

259291
#[async_trait]
260-
impl HttpHandler for TestHttpHandler {
292+
impl HttpHandler for TestHandler {
261293
async fn handle_request(
262294
&mut self,
263295
_ctx: &HttpContext,
@@ -272,23 +304,14 @@ impl HttpHandler for TestHttpHandler {
272304
self.response_counter.fetch_add(1, Ordering::Relaxed);
273305
decode_response(res).unwrap()
274306
}
275-
}
276307

277-
#[derive(Clone)]
278-
pub struct TestWebSocketHandler {
279-
pub message_counter: Arc<AtomicUsize>,
280-
}
281-
282-
impl TestWebSocketHandler {
283-
pub fn new() -> Self {
284-
Self {
285-
message_counter: Arc::new(AtomicUsize::new(0)),
286-
}
308+
async fn should_intercept(&mut self, _req: &Request<Body>) -> bool {
309+
self.should_intercept
287310
}
288311
}
289312

290313
#[async_trait]
291-
impl WebSocketHandler for TestWebSocketHandler {
314+
impl WebSocketHandler for TestHandler {
292315
async fn handle_message(&mut self, _ctx: &WebSocketContext, msg: Message) -> Option<Message> {
293316
self.message_counter.fetch_add(1, Ordering::Relaxed);
294317
Some(msg)

tests/openssl_ca.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ fn build_ca() -> OpensslAuthority {
1818

1919
#[tokio::test]
2020
async fn https_rustls() {
21-
let (proxy_addr, (http_handler, _), stop_proxy) = common::start_proxy(
21+
let (proxy_addr, handler, stop_proxy) = common::start_proxy(
2222
build_ca(),
2323
common::rustls_client(),
2424
common::rustls_websocket_connector(),
@@ -35,16 +35,16 @@ async fn https_rustls() {
3535
.unwrap();
3636

3737
assert_eq!(res.status(), 200);
38-
assert_eq!(http_handler.request_counter.load(Ordering::Relaxed), 2);
39-
assert_eq!(http_handler.response_counter.load(Ordering::Relaxed), 1);
38+
assert_eq!(handler.request_counter.load(Ordering::Relaxed), 2);
39+
assert_eq!(handler.response_counter.load(Ordering::Relaxed), 1);
4040

4141
stop_server.send(()).unwrap();
4242
stop_proxy.send(()).unwrap();
4343
}
4444

4545
#[tokio::test]
4646
async fn https_native_tls() {
47-
let (proxy_addr, (http_handler, _), stop_proxy) = common::start_proxy(
47+
let (proxy_addr, handler, stop_proxy) = common::start_proxy(
4848
build_ca(),
4949
common::native_tls_client(),
5050
common::native_tls_websocket_connector(),
@@ -61,8 +61,34 @@ async fn https_native_tls() {
6161
.unwrap();
6262

6363
assert_eq!(res.status(), 200);
64-
assert_eq!(http_handler.request_counter.load(Ordering::Relaxed), 2);
65-
assert_eq!(http_handler.response_counter.load(Ordering::Relaxed), 1);
64+
assert_eq!(handler.request_counter.load(Ordering::Relaxed), 2);
65+
assert_eq!(handler.response_counter.load(Ordering::Relaxed), 1);
66+
67+
stop_server.send(()).unwrap();
68+
stop_proxy.send(()).unwrap();
69+
}
70+
71+
#[tokio::test]
72+
async fn without_intercept() {
73+
let (proxy_addr, handler, stop_proxy) = common::start_proxy_without_intercept(
74+
build_ca(),
75+
common::http_client(),
76+
common::plain_websocket_connector(),
77+
)
78+
.unwrap();
79+
80+
let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap();
81+
let client = common::build_client(&proxy_addr.to_string());
82+
83+
let res = client
84+
.get(format!("https://localhost:{}/hello", server_addr.port()))
85+
.send()
86+
.await
87+
.unwrap();
88+
89+
assert_eq!(res.status(), 200);
90+
assert_eq!(handler.request_counter.load(Ordering::Relaxed), 1);
91+
assert_eq!(handler.response_counter.load(Ordering::Relaxed), 0);
6692

6793
stop_server.send(()).unwrap();
6894
stop_proxy.send(()).unwrap();

0 commit comments

Comments
 (0)