Skip to content

Commit ee82583

Browse files
committed
more tests
1 parent c8d1ed6 commit ee82583

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

pgdog/src/frontend/client/query_engine/test/rewrite_offset.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::frontend::router::parser::rewrite::statement::{
22
offset::OffsetPlan, plan::RewriteResult,
33
};
4+
use crate::frontend::router::parser::route::{Route, Shard, ShardWithPriority};
45
use crate::frontend::router::parser::Limit;
56

67
use super::prelude::*;
@@ -21,6 +22,16 @@ async fn run_test(messages: Vec<ProtocolMessage>) -> Option<OffsetPlan> {
2122
}
2223
}
2324

25+
fn cross_shard_route() -> Route {
26+
Route::select(
27+
ShardWithPriority::new_table(Shard::All),
28+
vec![],
29+
Default::default(),
30+
Limit::default(),
31+
None,
32+
)
33+
}
34+
2435
#[tokio::test]
2536
async fn test_offset_limit_literals() {
2637
let offset = run_test(vec![ProtocolMessage::Query(Query::new(
@@ -104,3 +115,130 @@ async fn test_offset_limit_no_select() {
104115

105116
assert!(offset.is_none());
106117
}
118+
119+
#[tokio::test]
120+
async fn test_offset_with_unique_id_simple() {
121+
unsafe {
122+
std::env::set_var("NODE_ID", "pgdog-1");
123+
}
124+
let sql = "SELECT pgdog.unique_id() FROM test LIMIT 10 OFFSET 5";
125+
let mut client = test_sharded_client();
126+
client.client_request = ClientRequest::from(vec![ProtocolMessage::Query(Query::new(sql))]);
127+
128+
let mut engine = QueryEngine::from_client(&client).unwrap();
129+
let mut context = QueryEngineContext::new(&mut client);
130+
131+
engine.parse_and_rewrite(&mut context).await.unwrap();
132+
133+
// After parse_and_rewrite, the Query message should have unique_id replaced.
134+
let rewritten_sql = match &context.client_request.messages[0] {
135+
ProtocolMessage::Query(q) => q.query().to_owned(),
136+
_ => panic!("expected Query"),
137+
};
138+
assert!(
139+
!rewritten_sql.contains("pgdog.unique_id"),
140+
"unique_id should be replaced: {rewritten_sql}"
141+
);
142+
assert!(
143+
rewritten_sql.contains("::bigint"),
144+
"should have bigint cast: {rewritten_sql}"
145+
);
146+
147+
// apply_after_parser with a cross-shard route.
148+
context.client_request.route = Some(cross_shard_route());
149+
context
150+
.rewrite_result
151+
.as_ref()
152+
.unwrap()
153+
.apply_after_parser(context.client_request)
154+
.unwrap();
155+
156+
let final_sql = match &context.client_request.messages[0] {
157+
ProtocolMessage::Query(q) => q.query().to_owned(),
158+
_ => panic!("expected Query"),
159+
};
160+
161+
// unique_id rewrite must survive.
162+
assert!(
163+
!final_sql.contains("pgdog.unique_id"),
164+
"unique_id rewrite must survive apply_after_parser: {final_sql}"
165+
);
166+
assert!(
167+
final_sql.contains("::bigint"),
168+
"bigint cast must survive: {final_sql}"
169+
);
170+
// LIMIT/OFFSET must be rewritten for cross-shard.
171+
assert!(
172+
final_sql.contains("LIMIT 15"),
173+
"LIMIT should be 10+5=15: {final_sql}"
174+
);
175+
assert!(
176+
final_sql.contains("OFFSET 0"),
177+
"OFFSET should be 0: {final_sql}"
178+
);
179+
}
180+
181+
#[tokio::test]
182+
async fn test_offset_with_unique_id_extended() {
183+
unsafe {
184+
std::env::set_var("NODE_ID", "pgdog-1");
185+
}
186+
let sql = "SELECT pgdog.unique_id(), $1 FROM test LIMIT $2 OFFSET $3";
187+
let mut client = test_sharded_client();
188+
client.client_request = ClientRequest::from(vec![
189+
ProtocolMessage::Parse(Parse::new_anonymous(sql)),
190+
ProtocolMessage::Bind(Bind::new_params(
191+
"",
192+
&[
193+
Parameter::new(b"hello"),
194+
Parameter::new(b"10"),
195+
Parameter::new(b"5"),
196+
],
197+
)),
198+
ProtocolMessage::Execute(Execute::new()),
199+
ProtocolMessage::Sync(Sync),
200+
]);
201+
202+
let mut engine = QueryEngine::from_client(&client).unwrap();
203+
let mut context = QueryEngineContext::new(&mut client);
204+
205+
engine.parse_and_rewrite(&mut context).await.unwrap();
206+
207+
// After parse_and_rewrite, Parse should have unique_id rewritten to $4::bigint.
208+
let rewritten_sql = match &context.client_request.messages[0] {
209+
ProtocolMessage::Parse(p) => p.query().to_owned(),
210+
_ => panic!("expected Parse"),
211+
};
212+
assert_eq!(
213+
rewritten_sql,
214+
"SELECT $4::bigint, $1 FROM test LIMIT $2 OFFSET $3"
215+
);
216+
217+
// apply_after_parser with cross-shard route should only rewrite Bind params.
218+
context.client_request.route = Some(cross_shard_route());
219+
context
220+
.rewrite_result
221+
.as_ref()
222+
.unwrap()
223+
.apply_after_parser(context.client_request)
224+
.unwrap();
225+
226+
// SQL unchanged (all limit/offset are params).
227+
let final_sql = match &context.client_request.messages[0] {
228+
ProtocolMessage::Parse(p) => p.query().to_owned(),
229+
_ => panic!("expected Parse"),
230+
};
231+
assert_eq!(
232+
final_sql, "SELECT $4::bigint, $1 FROM test LIMIT $2 OFFSET $3",
233+
"SQL must be unchanged for all-param case"
234+
);
235+
236+
// Bind params: $1=hello unchanged, $2=limit rewritten to 15, $3=offset rewritten to 0.
237+
if let ProtocolMessage::Bind(bind) = &context.client_request.messages[1] {
238+
assert_eq!(bind.params_raw()[0].data.as_ref(), b"hello");
239+
assert_eq!(bind.params_raw()[1].data.as_ref(), b"15");
240+
assert_eq!(bind.params_raw()[2].data.as_ref(), b"0");
241+
} else {
242+
panic!("expected Bind");
243+
}
244+
}

0 commit comments

Comments
 (0)