Skip to content

Commit ad09807

Browse files
committed
refactor: optimize project context loading and request handling
1 parent 2da6d06 commit ad09807

File tree

1 file changed

+80
-136
lines changed

1 file changed

+80
-136
lines changed

src/handler/mod.rs

Lines changed: 80 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ use crate::models::{DataSource, Error, ErrorTracking};
1313
use crate::tinybird::{ErrorRow, ErrorTrackingRow, EventRow};
1414
use axum::Json;
1515
use axum::http::{HeaderMap, StatusCode};
16-
use futures::TryStreamExt;
17-
use serde::Deserialize;
16+
use serde::{Deserialize, Serialize};
1817
use serde_json::Value;
1918
use sqlx::Row;
2019
use std::borrow::Cow;
@@ -79,20 +78,26 @@ pub fn error_response(status: StatusCode, message: &str) -> HandlerResponse {
7978
(status, Json(serde_json::json!({ "error": message })))
8079
}
8180

81+
#[derive(Serialize)]
82+
struct SimpleResponse {
83+
status: &'static str,
84+
}
85+
86+
#[derive(Serialize)]
87+
struct WarningResponse {
88+
warnings: HashMap<String, String>,
89+
}
90+
8291
pub fn success_response(warnings: HashMap<String, String>) -> HandlerResponse {
8392
if warnings.is_empty() {
8493
(
8594
StatusCode::OK,
86-
Json(serde_json::json!({ "status": "success" })),
95+
Json(serde_json::to_value(SimpleResponse { status: "success" }).unwrap()),
8796
)
8897
} else {
89-
let warnings_obj: serde_json::Map<String, Value> = warnings
90-
.into_iter()
91-
.map(|(k, v)| (k, Value::String(v)))
92-
.collect();
9398
(
9499
StatusCode::OK,
95-
Json(serde_json::json!({ "warnings": warnings_obj })),
100+
Json(serde_json::to_value(WarningResponse { warnings }).unwrap()),
96101
)
97102
}
98103
}
@@ -116,133 +121,72 @@ pub async fn load_project_context(
116121
pool: &sqlx::PgPool,
117122
token: &str,
118123
) -> Result<ProjectContext, HandlerResponse> {
119-
let mut rows = sqlx::query(
120-
"
121-
SELECT
122-
p.id,
123-
p.owner_id,
124-
p.domain,
125-
d.reference_id,
126-
d.name,
127-
d.data_type::text AS data_type,
128-
p.error_tracking_enabled,
129-
d.regex,
130-
d.allow_negative,
131-
d.allow_float,
132-
d.min_value,
133-
d.max_value,
134-
d.is_array,
135-
CASE
136-
WHEN u.id IS NOT NULL THEN p.owner_id
137-
WHEN o.id IS NOT NULL THEN m.user_id
138-
ELSE p.owner_id
139-
END AS billing_customer_id,
140-
o.id AS organization_id
124+
let rows = sqlx::query(
125+
r#"
126+
SELECT p.id, p.owner_id, p.domain, p.error_tracking_enabled, o.id AS organization_id,
127+
d.reference_id, d.name, d.data_type::text, d.regex, d.allow_negative,
128+
d.allow_float, d.min_value, d.max_value, d.is_array
141129
FROM project p
142130
LEFT JOIN data_sources d ON d.project_id = p.id
143-
LEFT JOIN \"user\" u ON u.id = p.owner_id
144131
LEFT JOIN organization o ON o.id = p.owner_id
145-
LEFT JOIN member m ON m.organization_id = o.id AND m.role = 'owner'
146132
WHERE p.token = $1
147-
",
133+
"#,
148134
)
149135
.bind(token)
150-
.fetch(pool);
151-
152-
let first_row = rows
153-
.try_next()
154-
.await
155-
.map_err(|_| error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
136+
.fetch_all(pool)
137+
.await
138+
.map_err(|_| error_response(StatusCode::INTERNAL_SERVER_ERROR, "DB Error"))?;
156139

157-
let Some(first) = first_row else {
140+
if rows.is_empty() {
158141
return Err(error_response(StatusCode::UNAUTHORIZED, "Unauthorized"));
159-
};
160-
161-
let project_id: Uuid = first
162-
.try_get("id")
163-
.map_err(|_| error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
164-
let owner_id: String = first
165-
.try_get("billing_customer_id")
166-
.or_else(|_| first.try_get("owner_id"))
167-
.map_err(|_| error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
168-
let organization_id: Option<String> = first.try_get("organization_id").unwrap_or(None);
169-
let domain: Option<String> = first.try_get("domain").unwrap_or(None);
170-
let error_tracking_enabled: bool = first.try_get("error_tracking_enabled").unwrap_or(false);
171-
172-
let mut datasources: HashMap<String, DataSource> = HashMap::new();
173-
174-
let process_row = |row: &sqlx::postgres::PgRow| -> Option<DataSource> {
175-
let reference_id: String = row
176-
.try_get::<Option<String>, _>("reference_id")
177-
.ok()
178-
.flatten()
179-
.unwrap_or_default();
180-
if reference_id.is_empty() {
181-
return None;
182-
}
183-
Some(DataSource {
184-
reference_id,
185-
name: row
186-
.try_get::<Option<String>, _>("name")
187-
.ok()
188-
.flatten()
189-
.unwrap_or_default(),
190-
data_type: row.try_get::<String, _>("data_type").unwrap_or_default(),
191-
regex: row.try_get::<Option<String>, _>("regex").ok().flatten(),
192-
allow_negative: row
193-
.try_get::<Option<bool>, _>("allow_negative")
194-
.ok()
195-
.flatten(),
196-
allow_float: row.try_get::<Option<bool>, _>("allow_float").ok().flatten(),
197-
min_value: row.try_get::<Option<f64>, _>("min_value").ok().flatten(),
198-
max_value: row.try_get::<Option<f64>, _>("max_value").ok().flatten(),
199-
is_array: row
200-
.try_get::<Option<bool>, _>("is_array")
201-
.ok()
202-
.flatten()
203-
.unwrap_or(false),
204-
})
205-
};
206-
207-
if let Some(ds) = process_row(&first) {
208-
datasources.insert(ds.reference_id.clone(), ds);
209142
}
210143

211-
while let Some(row) = rows
212-
.try_next()
213-
.await
214-
.map_err(|_| error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?
215-
{
216-
if let Some(ds) = process_row(&row) {
217-
datasources.insert(ds.reference_id.clone(), ds);
144+
let first = &rows[0];
145+
let mut datasources = HashMap::with_capacity(rows.len());
146+
147+
for row in &rows {
148+
if let Ok(Some(ref_id)) = row.try_get::<Option<String>, _>("reference_id") {
149+
datasources.insert(
150+
ref_id.clone(),
151+
DataSource {
152+
reference_id: ref_id,
153+
name: row
154+
.try_get::<Option<String>, _>("name")
155+
.ok()
156+
.flatten()
157+
.unwrap_or_default(),
158+
data_type: row.try_get::<String, _>("data_type").unwrap_or_default(),
159+
regex: row.try_get("regex").ok(),
160+
allow_negative: row.try_get("allow_negative").ok(),
161+
allow_float: row.try_get("allow_float").ok(),
162+
min_value: row.try_get("min_value").ok(),
163+
max_value: row.try_get("max_value").ok(),
164+
is_array: row
165+
.try_get::<Option<bool>, _>("is_array")
166+
.ok()
167+
.flatten()
168+
.unwrap_or(false),
169+
},
170+
);
218171
}
219172
}
220173

221-
let ip_rules: Vec<IpRule> =
222-
sqlx::query("SELECT ip_address, allowed FROM ip_addresses WHERE project_id = $1")
223-
.bind(project_id)
224-
.fetch_all(pool)
225-
.await
226-
.map_err(|_| {
227-
error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
228-
})?
229-
.into_iter()
230-
.map(|row| IpRule {
231-
ip_address: row.try_get("ip_address").unwrap_or_default(),
232-
allowed: row
233-
.try_get::<Option<bool>, _>("allowed")
234-
.unwrap_or(Some(true))
235-
.unwrap_or(true),
236-
})
237-
.collect();
174+
let ip_rules = sqlx::query_as!(
175+
IpRule,
176+
"SELECT ip_address, allowed FROM ip_addresses WHERE project_id = $1",
177+
first.get::<Uuid, _>("id")
178+
)
179+
.fetch_all(pool)
180+
.await
181+
.unwrap_or_default();
238182

239183
Ok(ProjectContext {
240-
project_id,
241-
owner_id,
242-
organization_id,
243-
domain,
184+
project_id: first.get("id"),
185+
owner_id: first.get("owner_id"),
186+
organization_id: first.get("organization_id"),
187+
domain: first.get("domain"),
244188
datasources,
245-
error_tracking_enabled,
189+
error_tracking_enabled: first.get("error_tracking_enabled"),
246190
ip_rules,
247191
})
248192
}
@@ -308,25 +252,25 @@ pub fn check_ip_allowed(ip_rules: &[IpRule], client_ip: &str) -> Result<(), &'st
308252
return Ok(());
309253
}
310254

311-
let has_whitelist = ip_rules.iter().any(|r| r.allowed);
255+
let mut has_whitelist = false;
256+
let mut allowed_by_whitelist = false;
312257

313-
if has_whitelist {
314-
if ip_rules
315-
.iter()
316-
.any(|r| r.allowed && r.ip_address == client_ip)
317-
{
318-
Ok(())
319-
} else {
320-
Err("IP address not allowed")
258+
for rule in ip_rules {
259+
if rule.allowed {
260+
has_whitelist = true;
261+
if rule.ip_address == client_ip {
262+
allowed_by_whitelist = true;
263+
}
264+
} else if rule.ip_address == client_ip {
265+
return Err("IP address blocked");
321266
}
322-
} else if ip_rules
323-
.iter()
324-
.any(|r| !r.allowed && r.ip_address == client_ip)
325-
{
326-
Err("IP address blocked")
327-
} else {
328-
Ok(())
329267
}
268+
269+
if has_whitelist && !allowed_by_whitelist {
270+
return Err("IP address not allowed");
271+
}
272+
273+
Ok(())
330274
}
331275

332276
pub fn enrich_data_with_country(data: &mut HashMap<String, Value>, headers: &HeaderMap) {
@@ -464,7 +408,7 @@ async fn process_collect_request(
464408
.map_err(|_| "Unauthorized or database error")?;
465409

466410
let req: crate::models::Request =
467-
serde_json::from_slice(&request.body).map_err(|_| "Invalid JSON".to_string())?;
411+
serde_json::from_slice(&request.body).map_err(|_| "Invalid JSON")?;
468412

469413
let server_id = req
470414
.id

0 commit comments

Comments
 (0)