Skip to content

Commit 08d947e

Browse files
authored
Merge branch 'master' into refactor/auth-manager-design-191
2 parents 690a499 + 0f69e19 commit 08d947e

File tree

3 files changed

+84
-60
lines changed

3 files changed

+84
-60
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ impl SimpleQueryHandler for DfSessionService {
139139

140140
let mut results = vec![];
141141
'stmt: for statement in statements {
142-
// TODO: improve statement check by using statement directly
143142
let query = statement.to_string();
144-
let query_lower = query.to_lowercase().trim().to_string();
145143

146144
// Call query hooks with the parsed statement
147145
for hook in &self.query_hooks {
@@ -179,7 +177,7 @@ impl SimpleQueryHandler for DfSessionService {
179177
}
180178
};
181179

182-
if query_lower.starts_with("insert into") {
180+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
183181
let resp = map_rows_affected_for_insert(&df).await?;
184182
results.push(resp);
185183
} else {
@@ -265,13 +263,7 @@ impl ExtendedQueryHandler for DfSessionService {
265263
where
266264
C: ClientInfo + Unpin + Send + Sync,
267265
{
268-
let query = portal
269-
.statement
270-
.statement
271-
.0
272-
.to_lowercase()
273-
.trim()
274-
.to_string();
266+
let query = &portal.statement.statement.0;
275267
log::debug!("Received execute extended query: {query}"); // Log for debugging
276268

277269
// Check query hooks first
@@ -302,7 +294,7 @@ impl ExtendedQueryHandler for DfSessionService {
302294
}
303295
}
304296

305-
if let (_, Some((_, plan))) = &portal.statement.statement {
297+
if let (_, Some((statement, plan))) = &portal.statement.statement {
306298
let param_types = plan
307299
.get_parameter_types()
308300
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -345,7 +337,7 @@ impl ExtendedQueryHandler for DfSessionService {
345337
}
346338
};
347339

348-
if query.starts_with("insert into") {
340+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
349341
let resp = map_rows_affected_for_insert(&dataframe).await?;
350342

351343
Ok(resp)

datafusion-postgres/src/hooks/permissions.rs

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ impl PermissionsHook {
2323
PermissionsHook { auth_manager }
2424
}
2525

26-
/// Check if the current user has permission to execute a query
27-
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
26+
/// Check if the current user has permission to execute a statement
27+
async fn check_statement_permission<C>(
28+
&self,
29+
client: &C,
30+
statement: &Statement,
31+
) -> PgWireResult<()>
2832
where
2933
C: ClientInfo + ?Sized,
3034
{
@@ -35,29 +39,19 @@ impl PermissionsHook {
3539
.map(|s| s.as_str())
3640
.unwrap_or("anonymous");
3741

38-
// Parse query to determine required permissions
39-
let query_lower = query.to_lowercase();
40-
let query_trimmed = query_lower.trim();
41-
42-
let (required_permission, resource) = if query_trimmed.starts_with("select") {
43-
(Permission::Select, ResourceType::All)
44-
} else if query_trimmed.starts_with("insert") {
45-
(Permission::Insert, ResourceType::All)
46-
} else if query_trimmed.starts_with("update") {
47-
(Permission::Update, ResourceType::All)
48-
} else if query_trimmed.starts_with("delete") {
49-
(Permission::Delete, ResourceType::All)
50-
} else if query_trimmed.starts_with("create table")
51-
|| query_trimmed.starts_with("create view")
52-
{
53-
(Permission::Create, ResourceType::All)
54-
} else if query_trimmed.starts_with("drop") {
55-
(Permission::Drop, ResourceType::All)
56-
} else if query_trimmed.starts_with("alter") {
57-
(Permission::Alter, ResourceType::All)
58-
} else {
59-
// For other queries (SHOW, EXPLAIN, etc.), allow all users
60-
return Ok(());
42+
// Determine required permissions based on Statement type
43+
let (required_permission, resource) = match statement {
44+
Statement::Query(_) => (Permission::Select, ResourceType::All),
45+
Statement::Insert(_) => (Permission::Insert, ResourceType::All),
46+
Statement::Update { .. } => (Permission::Update, ResourceType::All),
47+
Statement::Delete(_) => (Permission::Delete, ResourceType::All),
48+
Statement::CreateTable { .. } | Statement::CreateView { .. } => {
49+
(Permission::Create, ResourceType::All)
50+
}
51+
Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
52+
Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
53+
// For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
54+
_ => return Ok(()),
6155
};
6256

6357
// Check permission
@@ -78,6 +72,21 @@ impl PermissionsHook {
7872

7973
Ok(())
8074
}
75+
76+
/// Check if a statement should skip permission checks
77+
fn should_skip_permission_check(statement: &Statement) -> bool {
78+
matches!(
79+
statement,
80+
Statement::Set { .. }
81+
| Statement::ShowVariable { .. }
82+
| Statement::ShowStatus { .. }
83+
| Statement::StartTransaction { .. }
84+
| Statement::Commit { .. }
85+
| Statement::Rollback { .. }
86+
| Statement::Savepoint { .. }
87+
| Statement::ReleaseSavepoint { .. }
88+
)
89+
}
8190
}
8291

8392
#[async_trait]
@@ -89,22 +98,13 @@ impl QueryHook for PermissionsHook {
8998
_session_context: &SessionContext,
9099
client: &mut (dyn ClientInfo + Send + Sync),
91100
) -> Option<PgWireResult<Response>> {
92-
let query_lower = statement.to_string().to_lowercase();
93-
94-
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
95-
if !query_lower.starts_with("set")
96-
&& !query_lower.starts_with("begin")
97-
&& !query_lower.starts_with("commit")
98-
&& !query_lower.starts_with("rollback")
99-
&& !query_lower.starts_with("start")
100-
&& !query_lower.starts_with("end")
101-
&& !query_lower.starts_with("abort")
102-
&& !query_lower.starts_with("show")
103-
{
104-
let res = self.check_query_permission(&*client, &query_lower).await;
105-
if let Err(e) = res {
106-
return Some(Err(e));
107-
}
101+
if Self::should_skip_permission_check(statement) {
102+
return None;
103+
}
104+
105+
// Check permissions for other statements
106+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
107+
return Some(Err(e));
108108
}
109109

110110
None
@@ -127,15 +127,15 @@ impl QueryHook for PermissionsHook {
127127
_session_context: &SessionContext,
128128
client: &mut (dyn ClientInfo + Send + Sync),
129129
) -> Option<PgWireResult<Response>> {
130-
let query = statement.to_string().to_lowercase();
130+
if Self::should_skip_permission_check(statement) {
131+
return None;
132+
}
131133

132-
// Check permissions for the query (skip for SET and SHOW statements)
133-
if !query.starts_with("set") && !query.starts_with("show") {
134-
let res = self.check_query_permission(&*client, &query).await;
135-
if let Err(e) = res {
136-
return Some(Err(e));
137-
}
134+
// Check permissions for other statements
135+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
136+
return Some(Err(e));
138137
}
138+
139139
None
140140
}
141141
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use pgwire::api::query::SimpleQueryHandler;
2+
3+
use datafusion_postgres::testing::*;
4+
5+
// pgAdmin startup queries from issue #178
6+
// https://github.com/datafusion-contrib/datafusion-postgres/issues/178
7+
const PGADMIN_QUERIES: &[&str] = &[
8+
// Basic version query (fixed by #179)
9+
"SELECT version()",
10+
// Query to check for BDR extension and replication slots
11+
r#"SELECT CASE
12+
WHEN (SELECT count(extname) FROM pg_catalog.pg_extension WHERE extname='bdr') > 0
13+
THEN 'pgd'
14+
WHEN (SELECT COUNT(*) FROM pg_replication_slots) > 0
15+
THEN 'log'
16+
ELSE NULL
17+
END as type"#,
18+
];
19+
20+
#[tokio::test]
21+
pub async fn test_pgadmin_startup_sql() {
22+
let service = setup_handlers();
23+
let mut client = MockClient::new();
24+
25+
for query in PGADMIN_QUERIES {
26+
SimpleQueryHandler::do_query(&service, &mut client, query)
27+
.await
28+
.unwrap_or_else(|e| {
29+
panic!("failed to run sql:\n--------------\n{query}\n--------------\n{e}")
30+
});
31+
}
32+
}

0 commit comments

Comments
 (0)