11use crate :: frontend:: router:: parser:: rewrite:: statement:: {
22 offset:: OffsetPlan , plan:: RewriteResult ,
33} ;
4+ use crate :: frontend:: router:: parser:: route:: { Route , Shard , ShardWithPriority } ;
45use crate :: frontend:: router:: parser:: Limit ;
56
67use super :: prelude:: * ;
@@ -21,6 +22,16 @@ async fn run_test(messages: Vec<ProtocolMessage>) -> Option<OffsetPlan> {
2122 }
2223}
2324
25+ fn cross_shard_route ( ) -> Route {
26+ Route :: select (
27+ ShardWithPriority :: new_table ( Shard :: All ) ,
28+ vec ! [ ] ,
29+ Default :: default ( ) ,
30+ Limit :: default ( ) ,
31+ None ,
32+ )
33+ }
34+
2435#[ tokio:: test]
2536async fn test_offset_limit_literals ( ) {
2637 let offset = run_test ( vec ! [ ProtocolMessage :: Query ( Query :: new(
@@ -104,3 +115,130 @@ async fn test_offset_limit_no_select() {
104115
105116 assert ! ( offset. is_none( ) ) ;
106117}
118+
119+ #[ tokio:: test]
120+ async fn test_offset_with_unique_id_simple ( ) {
121+ unsafe {
122+ std:: env:: set_var ( "NODE_ID" , "pgdog-1" ) ;
123+ }
124+ let sql = "SELECT pgdog.unique_id() FROM test LIMIT 10 OFFSET 5" ;
125+ let mut client = test_sharded_client ( ) ;
126+ client. client_request = ClientRequest :: from ( vec ! [ ProtocolMessage :: Query ( Query :: new( sql) ) ] ) ;
127+
128+ let mut engine = QueryEngine :: from_client ( & client) . unwrap ( ) ;
129+ let mut context = QueryEngineContext :: new ( & mut client) ;
130+
131+ engine. parse_and_rewrite ( & mut context) . await . unwrap ( ) ;
132+
133+ // After parse_and_rewrite, the Query message should have unique_id replaced.
134+ let rewritten_sql = match & context. client_request . messages [ 0 ] {
135+ ProtocolMessage :: Query ( q) => q. query ( ) . to_owned ( ) ,
136+ _ => panic ! ( "expected Query" ) ,
137+ } ;
138+ assert ! (
139+ !rewritten_sql. contains( "pgdog.unique_id" ) ,
140+ "unique_id should be replaced: {rewritten_sql}"
141+ ) ;
142+ assert ! (
143+ rewritten_sql. contains( "::bigint" ) ,
144+ "should have bigint cast: {rewritten_sql}"
145+ ) ;
146+
147+ // apply_after_parser with a cross-shard route.
148+ context. client_request . route = Some ( cross_shard_route ( ) ) ;
149+ context
150+ . rewrite_result
151+ . as_ref ( )
152+ . unwrap ( )
153+ . apply_after_parser ( context. client_request )
154+ . unwrap ( ) ;
155+
156+ let final_sql = match & context. client_request . messages [ 0 ] {
157+ ProtocolMessage :: Query ( q) => q. query ( ) . to_owned ( ) ,
158+ _ => panic ! ( "expected Query" ) ,
159+ } ;
160+
161+ // unique_id rewrite must survive.
162+ assert ! (
163+ !final_sql. contains( "pgdog.unique_id" ) ,
164+ "unique_id rewrite must survive apply_after_parser: {final_sql}"
165+ ) ;
166+ assert ! (
167+ final_sql. contains( "::bigint" ) ,
168+ "bigint cast must survive: {final_sql}"
169+ ) ;
170+ // LIMIT/OFFSET must be rewritten for cross-shard.
171+ assert ! (
172+ final_sql. contains( "LIMIT 15" ) ,
173+ "LIMIT should be 10+5=15: {final_sql}"
174+ ) ;
175+ assert ! (
176+ final_sql. contains( "OFFSET 0" ) ,
177+ "OFFSET should be 0: {final_sql}"
178+ ) ;
179+ }
180+
181+ #[ tokio:: test]
182+ async fn test_offset_with_unique_id_extended ( ) {
183+ unsafe {
184+ std:: env:: set_var ( "NODE_ID" , "pgdog-1" ) ;
185+ }
186+ let sql = "SELECT pgdog.unique_id(), $1 FROM test LIMIT $2 OFFSET $3" ;
187+ let mut client = test_sharded_client ( ) ;
188+ client. client_request = ClientRequest :: from ( vec ! [
189+ ProtocolMessage :: Parse ( Parse :: new_anonymous( sql) ) ,
190+ ProtocolMessage :: Bind ( Bind :: new_params(
191+ "" ,
192+ & [
193+ Parameter :: new( b"hello" ) ,
194+ Parameter :: new( b"10" ) ,
195+ Parameter :: new( b"5" ) ,
196+ ] ,
197+ ) ) ,
198+ ProtocolMessage :: Execute ( Execute :: new( ) ) ,
199+ ProtocolMessage :: Sync ( Sync ) ,
200+ ] ) ;
201+
202+ let mut engine = QueryEngine :: from_client ( & client) . unwrap ( ) ;
203+ let mut context = QueryEngineContext :: new ( & mut client) ;
204+
205+ engine. parse_and_rewrite ( & mut context) . await . unwrap ( ) ;
206+
207+ // After parse_and_rewrite, Parse should have unique_id rewritten to $4::bigint.
208+ let rewritten_sql = match & context. client_request . messages [ 0 ] {
209+ ProtocolMessage :: Parse ( p) => p. query ( ) . to_owned ( ) ,
210+ _ => panic ! ( "expected Parse" ) ,
211+ } ;
212+ assert_eq ! (
213+ rewritten_sql,
214+ "SELECT $4::bigint, $1 FROM test LIMIT $2 OFFSET $3"
215+ ) ;
216+
217+ // apply_after_parser with cross-shard route should only rewrite Bind params.
218+ context. client_request . route = Some ( cross_shard_route ( ) ) ;
219+ context
220+ . rewrite_result
221+ . as_ref ( )
222+ . unwrap ( )
223+ . apply_after_parser ( context. client_request )
224+ . unwrap ( ) ;
225+
226+ // SQL unchanged (all limit/offset are params).
227+ let final_sql = match & context. client_request . messages [ 0 ] {
228+ ProtocolMessage :: Parse ( p) => p. query ( ) . to_owned ( ) ,
229+ _ => panic ! ( "expected Parse" ) ,
230+ } ;
231+ assert_eq ! (
232+ final_sql, "SELECT $4::bigint, $1 FROM test LIMIT $2 OFFSET $3" ,
233+ "SQL must be unchanged for all-param case"
234+ ) ;
235+
236+ // Bind params: $1=hello unchanged, $2=limit rewritten to 15, $3=offset rewritten to 0.
237+ if let ProtocolMessage :: Bind ( bind) = & context. client_request . messages [ 1 ] {
238+ assert_eq ! ( bind. params_raw( ) [ 0 ] . data. as_ref( ) , b"hello" ) ;
239+ assert_eq ! ( bind. params_raw( ) [ 1 ] . data. as_ref( ) , b"15" ) ;
240+ assert_eq ! ( bind. params_raw( ) [ 2 ] . data. as_ref( ) , b"0" ) ;
241+ } else {
242+ panic ! ( "expected Bind" ) ;
243+ }
244+ }
0 commit comments