@@ -13,8 +13,7 @@ use crate::models::{DataSource, Error, ErrorTracking};
1313use crate :: tinybird:: { ErrorRow , ErrorTrackingRow , EventRow } ;
1414use axum:: Json ;
1515use axum:: http:: { HeaderMap , StatusCode } ;
16- use futures:: TryStreamExt ;
17- use serde:: Deserialize ;
16+ use serde:: { Deserialize , Serialize } ;
1817use serde_json:: Value ;
1918use sqlx:: Row ;
2019use std:: borrow:: Cow ;
@@ -79,20 +78,26 @@ pub fn error_response(status: StatusCode, message: &str) -> HandlerResponse {
7978 ( status, Json ( serde_json:: json!( { "error" : message } ) ) )
8079}
8180
81+ #[ derive( Serialize ) ]
82+ struct SimpleResponse {
83+ status : & ' static str ,
84+ }
85+
86+ #[ derive( Serialize ) ]
87+ struct WarningResponse {
88+ warnings : HashMap < String , String > ,
89+ }
90+
8291pub fn success_response ( warnings : HashMap < String , String > ) -> HandlerResponse {
8392 if warnings. is_empty ( ) {
8493 (
8594 StatusCode :: OK ,
86- Json ( serde_json:: json! ( { " status" : "success" } ) ) ,
95+ Json ( serde_json:: to_value ( SimpleResponse { status : "success" } ) . unwrap ( ) ) ,
8796 )
8897 } else {
89- let warnings_obj: serde_json:: Map < String , Value > = warnings
90- . into_iter ( )
91- . map ( |( k, v) | ( k, Value :: String ( v) ) )
92- . collect ( ) ;
9398 (
9499 StatusCode :: OK ,
95- Json ( serde_json:: json! ( { " warnings" : warnings_obj } ) ) ,
100+ Json ( serde_json:: to_value ( WarningResponse { warnings } ) . unwrap ( ) ) ,
96101 )
97102 }
98103}
@@ -116,133 +121,72 @@ pub async fn load_project_context(
116121 pool : & sqlx:: PgPool ,
117122 token : & str ,
118123) -> Result < ProjectContext , HandlerResponse > {
119- let mut rows = sqlx:: query (
120- "
121- SELECT
122- p.id,
123- p.owner_id,
124- p.domain,
125- d.reference_id,
126- d.name,
127- d.data_type::text AS data_type,
128- p.error_tracking_enabled,
129- d.regex,
130- d.allow_negative,
131- d.allow_float,
132- d.min_value,
133- d.max_value,
134- d.is_array,
135- CASE
136- WHEN u.id IS NOT NULL THEN p.owner_id
137- WHEN o.id IS NOT NULL THEN m.user_id
138- ELSE p.owner_id
139- END AS billing_customer_id,
140- o.id AS organization_id
124+ let rows = sqlx:: query (
125+ r#"
126+ SELECT p.id, p.owner_id, p.domain, p.error_tracking_enabled, o.id AS organization_id,
127+ d.reference_id, d.name, d.data_type::text, d.regex, d.allow_negative,
128+ d.allow_float, d.min_value, d.max_value, d.is_array
141129 FROM project p
142130 LEFT JOIN data_sources d ON d.project_id = p.id
143- LEFT JOIN \" user\" u ON u.id = p.owner_id
144131 LEFT JOIN organization o ON o.id = p.owner_id
145- LEFT JOIN member m ON m.organization_id = o.id AND m.role = 'owner'
146132 WHERE p.token = $1
147- " ,
133+ "# ,
148134 )
149135 . bind ( token)
150- . fetch ( pool) ;
151-
152- let first_row = rows
153- . try_next ( )
154- . await
155- . map_err ( |_| error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "Internal server error" ) ) ?;
136+ . fetch_all ( pool)
137+ . await
138+ . map_err ( |_| error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "DB Error" ) ) ?;
156139
157- let Some ( first ) = first_row else {
140+ if rows . is_empty ( ) {
158141 return Err ( error_response ( StatusCode :: UNAUTHORIZED , "Unauthorized" ) ) ;
159- } ;
160-
161- let project_id: Uuid = first
162- . try_get ( "id" )
163- . map_err ( |_| error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "Internal server error" ) ) ?;
164- let owner_id: String = first
165- . try_get ( "billing_customer_id" )
166- . or_else ( |_| first. try_get ( "owner_id" ) )
167- . map_err ( |_| error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "Internal server error" ) ) ?;
168- let organization_id: Option < String > = first. try_get ( "organization_id" ) . unwrap_or ( None ) ;
169- let domain: Option < String > = first. try_get ( "domain" ) . unwrap_or ( None ) ;
170- let error_tracking_enabled: bool = first. try_get ( "error_tracking_enabled" ) . unwrap_or ( false ) ;
171-
172- let mut datasources: HashMap < String , DataSource > = HashMap :: new ( ) ;
173-
174- let process_row = |row : & sqlx:: postgres:: PgRow | -> Option < DataSource > {
175- let reference_id: String = row
176- . try_get :: < Option < String > , _ > ( "reference_id" )
177- . ok ( )
178- . flatten ( )
179- . unwrap_or_default ( ) ;
180- if reference_id. is_empty ( ) {
181- return None ;
182- }
183- Some ( DataSource {
184- reference_id,
185- name : row
186- . try_get :: < Option < String > , _ > ( "name" )
187- . ok ( )
188- . flatten ( )
189- . unwrap_or_default ( ) ,
190- data_type : row. try_get :: < String , _ > ( "data_type" ) . unwrap_or_default ( ) ,
191- regex : row. try_get :: < Option < String > , _ > ( "regex" ) . ok ( ) . flatten ( ) ,
192- allow_negative : row
193- . try_get :: < Option < bool > , _ > ( "allow_negative" )
194- . ok ( )
195- . flatten ( ) ,
196- allow_float : row. try_get :: < Option < bool > , _ > ( "allow_float" ) . ok ( ) . flatten ( ) ,
197- min_value : row. try_get :: < Option < f64 > , _ > ( "min_value" ) . ok ( ) . flatten ( ) ,
198- max_value : row. try_get :: < Option < f64 > , _ > ( "max_value" ) . ok ( ) . flatten ( ) ,
199- is_array : row
200- . try_get :: < Option < bool > , _ > ( "is_array" )
201- . ok ( )
202- . flatten ( )
203- . unwrap_or ( false ) ,
204- } )
205- } ;
206-
207- if let Some ( ds) = process_row ( & first) {
208- datasources. insert ( ds. reference_id . clone ( ) , ds) ;
209142 }
210143
211- while let Some ( row) = rows
212- . try_next ( )
213- . await
214- . map_err ( |_| error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "Internal server error" ) ) ?
215- {
216- if let Some ( ds) = process_row ( & row) {
217- datasources. insert ( ds. reference_id . clone ( ) , ds) ;
144+ let first = & rows[ 0 ] ;
145+ let mut datasources = HashMap :: with_capacity ( rows. len ( ) ) ;
146+
147+ for row in & rows {
148+ if let Ok ( Some ( ref_id) ) = row. try_get :: < Option < String > , _ > ( "reference_id" ) {
149+ datasources. insert (
150+ ref_id. clone ( ) ,
151+ DataSource {
152+ reference_id : ref_id,
153+ name : row
154+ . try_get :: < Option < String > , _ > ( "name" )
155+ . ok ( )
156+ . flatten ( )
157+ . unwrap_or_default ( ) ,
158+ data_type : row. try_get :: < String , _ > ( "data_type" ) . unwrap_or_default ( ) ,
159+ regex : row. try_get ( "regex" ) . ok ( ) ,
160+ allow_negative : row. try_get ( "allow_negative" ) . ok ( ) ,
161+ allow_float : row. try_get ( "allow_float" ) . ok ( ) ,
162+ min_value : row. try_get ( "min_value" ) . ok ( ) ,
163+ max_value : row. try_get ( "max_value" ) . ok ( ) ,
164+ is_array : row
165+ . try_get :: < Option < bool > , _ > ( "is_array" )
166+ . ok ( )
167+ . flatten ( )
168+ . unwrap_or ( false ) ,
169+ } ,
170+ ) ;
218171 }
219172 }
220173
221- let ip_rules: Vec < IpRule > =
222- sqlx:: query ( "SELECT ip_address, allowed FROM ip_addresses WHERE project_id = $1" )
223- . bind ( project_id)
224- . fetch_all ( pool)
225- . await
226- . map_err ( |_| {
227- error_response ( StatusCode :: INTERNAL_SERVER_ERROR , "Internal server error" )
228- } ) ?
229- . into_iter ( )
230- . map ( |row| IpRule {
231- ip_address : row. try_get ( "ip_address" ) . unwrap_or_default ( ) ,
232- allowed : row
233- . try_get :: < Option < bool > , _ > ( "allowed" )
234- . unwrap_or ( Some ( true ) )
235- . unwrap_or ( true ) ,
236- } )
237- . collect ( ) ;
174+ let ip_rules = sqlx:: query_as!(
175+ IpRule ,
176+ "SELECT ip_address, allowed FROM ip_addresses WHERE project_id = $1" ,
177+ first. get:: <Uuid , _>( "id" )
178+ )
179+ . fetch_all ( pool)
180+ . await
181+ . unwrap_or_default ( ) ;
238182
239183 Ok ( ProjectContext {
240- project_id,
241- owner_id,
242- organization_id,
243- domain,
184+ project_id : first . get ( "id" ) ,
185+ owner_id : first . get ( "owner_id" ) ,
186+ organization_id : first . get ( "organization_id" ) ,
187+ domain : first . get ( "domain" ) ,
244188 datasources,
245- error_tracking_enabled,
189+ error_tracking_enabled : first . get ( "error_tracking_enabled" ) ,
246190 ip_rules,
247191 } )
248192}
@@ -308,25 +252,25 @@ pub fn check_ip_allowed(ip_rules: &[IpRule], client_ip: &str) -> Result<(), &'st
308252 return Ok ( ( ) ) ;
309253 }
310254
311- let has_whitelist = ip_rules. iter ( ) . any ( |r| r. allowed ) ;
255+ let mut has_whitelist = false ;
256+ let mut allowed_by_whitelist = false ;
312257
313- if has_whitelist {
314- if ip_rules
315- . iter ( )
316- . any ( |r| r . allowed && r . ip_address == client_ip)
317- {
318- Ok ( ( ) )
319- } else {
320- Err ( "IP address not allowed" )
258+ for rule in ip_rules {
259+ if rule . allowed {
260+ has_whitelist = true ;
261+ if rule . ip_address == client_ip {
262+ allowed_by_whitelist = true ;
263+ }
264+ } else if rule . ip_address == client_ip {
265+ return Err ( "IP address blocked" ) ;
321266 }
322- } else if ip_rules
323- . iter ( )
324- . any ( |r| !r. allowed && r. ip_address == client_ip)
325- {
326- Err ( "IP address blocked" )
327- } else {
328- Ok ( ( ) )
329267 }
268+
269+ if has_whitelist && !allowed_by_whitelist {
270+ return Err ( "IP address not allowed" ) ;
271+ }
272+
273+ Ok ( ( ) )
330274}
331275
332276pub fn enrich_data_with_country ( data : & mut HashMap < String , Value > , headers : & HeaderMap ) {
@@ -464,7 +408,7 @@ async fn process_collect_request(
464408 . map_err ( |_| "Unauthorized or database error" ) ?;
465409
466410 let req: crate :: models:: Request =
467- serde_json:: from_slice ( & request. body ) . map_err ( |_| "Invalid JSON" . to_string ( ) ) ?;
411+ serde_json:: from_slice ( & request. body ) . map_err ( |_| "Invalid JSON" ) ?;
468412
469413 let server_id = req
470414 . id
0 commit comments