Skip to content

Commit 5d13e11

Browse files
authored
Merge pull request #73 from tankyleo/default-db-config
Add default database settting, and env var overrides
2 parents b828ef6 + f153b49 commit 5d13e11

File tree

5 files changed

+266
-165
lines changed

5 files changed

+266
-165
lines changed

rust/auth-impls/src/jwt.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ const BEARER_PREFIX: &str = "Bearer ";
3434
impl JWTAuthorizer {
3535
/// Creates a new instance of [`JWTAuthorizer`], fails on failure to parse the PEM formatted RSA public key
3636
pub async fn new(rsa_pem: &str) -> Result<Self, String> {
37-
let jwt_issuer_key =
38-
DecodingKey::from_rsa_pem(rsa_pem.as_bytes()).map_err(|e| e.to_string())?;
37+
let jwt_issuer_key = DecodingKey::from_rsa_pem(rsa_pem.as_bytes())
38+
.map_err(|e| format!("Failed to parse the PEM formatted RSA public key: {}", e))?;
3939
Ok(Self { jwt_issuer_key })
4040
}
4141
}

rust/impls/src/postgres_store.rs

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464
/// A postgres backend with TLS connections to the database
6565
pub type PostgresTlsBackend = PostgresBackend<MakeTlsConnector>;
6666

67-
async fn make_postgres_db_connection<T>(postgres_endpoint: &str, tls: T) -> Result<Client, Error>
67+
async fn make_db_connection<T>(
68+
postgres_endpoint: &str, db_name: &str, tls: T,
69+
) -> Result<Client, Error>
6870
where
6971
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
7072
T::Stream: Send + Sync,
7173
T::TlsConnect: Send,
7274
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
7375
{
74-
let dsn = format!("{}/{}", postgres_endpoint, "postgres");
76+
let dsn = format!("{}/{}", postgres_endpoint, db_name);
7577
let (client, connection) = tokio_postgres::connect(&dsn, tls)
7678
.await
7779
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
@@ -84,16 +86,16 @@ where
8486
Ok(client)
8587
}
8688

87-
async fn initialize_vss_database<T>(
88-
postgres_endpoint: &str, db_name: &str, tls: T,
89+
async fn create_database<T>(
90+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
8991
) -> Result<(), Error>
9092
where
9193
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
9294
T::Stream: Send + Sync,
9395
T::TlsConnect: Send,
9496
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
9597
{
96-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
98+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
9799

98100
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
99101
Error::new(
@@ -113,14 +115,16 @@ where
113115
}
114116

115117
#[cfg(test)]
116-
async fn drop_database<T>(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<(), Error>
118+
async fn drop_database<T>(
119+
postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T,
120+
) -> Result<(), Error>
117121
where
118122
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
119123
T::Stream: Send + Sync,
120124
T::TlsConnect: Send,
121125
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
122126
{
123-
let client = make_postgres_db_connection(&postgres_endpoint, tls).await?;
127+
let client = make_db_connection(postgres_endpoint, default_db, tls).await?;
124128

125129
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
126130
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
@@ -133,25 +137,38 @@ where
133137

134138
impl PostgresPlaintextBackend {
135139
/// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136-
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
137-
PostgresBackend::new_internal(postgres_endpoint, db_name, NoTls).await
140+
pub async fn new(
141+
postgres_endpoint: &str, default_db: &str, vss_db: &str,
142+
) -> Result<Self, Error> {
143+
PostgresBackend::new_internal(postgres_endpoint, default_db, vss_db, NoTls).await
138144
}
139145
}
140146

141147
impl PostgresTlsBackend {
142148
/// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149
pub async fn new(
144-
postgres_endpoint: &str, db_name: &str, additional_certificate: Option<Certificate>,
150+
postgres_endpoint: &str, default_db: &str, vss_db: &str, crt_pem: Option<&str>,
145151
) -> Result<Self, Error> {
146152
let mut builder = TlsConnector::builder();
147-
if let Some(cert) = additional_certificate {
148-
builder.add_root_certificate(cert);
153+
if let Some(pem) = crt_pem {
154+
let crt = Certificate::from_pem(pem.as_bytes()).map_err(|e| {
155+
Error::new(
156+
ErrorKind::Other,
157+
format!("Failed to parse the PEM formatted certificate: {}", e),
158+
)
159+
})?;
160+
builder.add_root_certificate(crt);
149161
}
150162
let connector = builder.build().map_err(|e| {
151163
Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e))
152164
})?;
153-
PostgresBackend::new_internal(postgres_endpoint, db_name, MakeTlsConnector::new(connector))
154-
.await
165+
PostgresBackend::new_internal(
166+
postgres_endpoint,
167+
default_db,
168+
vss_db,
169+
MakeTlsConnector::new(connector),
170+
)
171+
.await
155172
}
156173
}
157174

@@ -162,9 +179,11 @@ where
162179
T::TlsConnect: Send,
163180
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
164181
{
165-
async fn new_internal(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<Self, Error> {
166-
initialize_vss_database(postgres_endpoint, db_name, tls.clone()).await?;
167-
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
182+
async fn new_internal(
183+
postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T,
184+
) -> Result<Self, Error> {
185+
create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?;
186+
let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db);
168187
let manager =
169188
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
170189
Error::new(
@@ -649,24 +668,27 @@ mod tests {
649668
use tokio_postgres::NoTls;
650669

651670
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
671+
const DEFAULT_DB: &str = "postgres";
652672
const MIGRATIONS_START: usize = 0;
653673
const MIGRATIONS_END: usize = MIGRATIONS.len();
654674

655675
static START: OnceCell<()> = OnceCell::const_new();
656676

657677
define_kv_store_tests!(PostgresKvStoreTest, PostgresPlaintextBackend, {
658-
let db_name = "postgres_kv_store_tests";
678+
let vss_db = "postgres_kv_store_tests";
659679
START
660680
.get_or_init(|| async {
661-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
662-
let store =
663-
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
681+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
682+
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db)
683+
.await
684+
.unwrap();
664685
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
665686
assert_eq!(start, MIGRATIONS_START);
666687
assert_eq!(end, MIGRATIONS_END);
667688
})
668689
.await;
669-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
690+
let store =
691+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
670692
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
671693
assert_eq!(start, MIGRATIONS_END);
672694
assert_eq!(end, MIGRATIONS_END);
@@ -678,36 +700,40 @@ mod tests {
678700
#[tokio::test]
679701
#[should_panic(expected = "We do not allow downgrades")]
680702
async fn panic_on_downgrade() {
681-
let db_name = "panic_on_downgrade_test";
682-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
703+
let vss_db = "panic_on_downgrade_test";
704+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
683705
{
684706
let mut migrations = MIGRATIONS.to_vec();
685707
migrations.push(DUMMY_MIGRATION);
686-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
708+
let store =
709+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
687710
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
688711
assert_eq!(start, MIGRATIONS_START);
689712
assert_eq!(end, MIGRATIONS_END + 1);
690713
};
691714
{
692-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
715+
let store =
716+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
693717
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
694718
};
695719
}
696720

697721
#[tokio::test]
698722
async fn new_migrations_increments_upgrades() {
699-
let db_name = "new_migrations_increments_upgrades_test";
700-
let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await;
723+
let vss_db = "new_migrations_increments_upgrades_test";
724+
let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await;
701725
{
702-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
726+
let store =
727+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
703728
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
704729
assert_eq!(start, MIGRATIONS_START);
705730
assert_eq!(end, MIGRATIONS_END);
706731
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
707732
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
708733
};
709734
{
710-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
735+
let store =
736+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
711737
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
712738
assert_eq!(start, MIGRATIONS_END);
713739
assert_eq!(end, MIGRATIONS_END);
@@ -718,7 +744,8 @@ mod tests {
718744
let mut migrations = MIGRATIONS.to_vec();
719745
migrations.push(DUMMY_MIGRATION);
720746
{
721-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
747+
let store =
748+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
722749
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
723750
assert_eq!(start, MIGRATIONS_END);
724751
assert_eq!(end, MIGRATIONS_END + 1);
@@ -729,7 +756,8 @@ mod tests {
729756
migrations.push(DUMMY_MIGRATION);
730757
migrations.push(DUMMY_MIGRATION);
731758
{
732-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
759+
let store =
760+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
733761
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
734762
assert_eq!(start, MIGRATIONS_END + 1);
735763
assert_eq!(end, MIGRATIONS_END + 3);
@@ -741,13 +769,14 @@ mod tests {
741769
};
742770

743771
{
744-
let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
772+
let store =
773+
PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap();
745774
let list = store.get_upgrades_list().await;
746775
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
747776
let version = store.get_schema_version().await;
748777
assert_eq!(version, MIGRATIONS_END + 3);
749778
}
750779

751-
drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await.unwrap();
780+
drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap();
752781
}
753782
}

0 commit comments

Comments
 (0)